Commit 9db1f1e: align (#30)
Parent: 90e8d6c

2 files changed: +22 -10 lines

train_img.py (1 addition & 3 deletions)
@@ -103,8 +103,6 @@ def main(args):
         dtype = torch.bfloat16
     elif args.mixed_precision == "fp16":
         dtype = torch.float16
-    elif args.mixed_precision == "fp32":
-        dtype = torch.float32
     else:
         raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
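Taken together with the argparse change below, this hunk narrows mixed precision to bf16 and fp16 only. A minimal standalone sketch of the resulting selection logic, with a hypothetical torch.autocast usage at the end (the commit itself does not show how dtype is consumed):

import torch

def resolve_dtype(mixed_precision: str) -> torch.dtype:
    # Mirrors the logic after this commit: fp32 is no longer a valid choice.
    if mixed_precision == "bf16":
        return torch.bfloat16
    elif mixed_precision == "fp16":
        return torch.float16
    else:
        raise ValueError(f"Unknown mixed precision {mixed_precision}")

# Hypothetical usage; how train_img.py applies dtype is outside this hunk.
dtype = resolve_dtype("bf16")
with torch.autocast(device_type="cuda", dtype=dtype):
    pass  # the model forward pass would run under autocast here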
@@ -283,7 +281,7 @@ def main(args):
     parser.add_argument("--log-every", type=int, default=10)
     parser.add_argument("--ckpt-every", type=int, default=1000)

-    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
     parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
     parser.add_argument("--grad_checkpoint", action="store_true", help="Use gradient checkpointing")

train_video.py (21 additions & 7 deletions)
@@ -6,6 +6,7 @@

 import colossalai
 import torch
+import torch.distributed as dist
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
@@ -49,6 +50,7 @@ def main(args):
     model_string_name = args.model.replace("/", "-")
     # Create an experiment folder
     experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
+    dist.barrier()
     if coordinator.is_master():
         os.makedirs(experiment_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
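A sketch of the synchronization pattern this hunk adds. Reading between the lines (the commit message does not say), the barrier keeps all ranks aligned before the master touches the shared output directory, so no rank is still scanning the folder list while rank 0 mutates it:

import os
import torch.distributed as dist

def create_experiment_dir(experiment_dir: str, is_master: bool) -> None:
    # Every rank waits here; none proceeds until all ranks arrive.
    dist.barrier()
    if is_master:
        # Only rank 0 performs the filesystem side effects.
        os.makedirs(experiment_dir, exist_ok=True)

A common variant places a second barrier after the makedirs as well, so non-master ranks cannot try to use the directory before it exists; the commit adds only the leading barrier.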
@@ -97,7 +99,12 @@ def main(args):

     # Create model
     img_size = dataset[0][0].shape[-1]
-    dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+    if args.mixed_precision == "bf16":
+        dtype = torch.bfloat16
+    elif args.mixed_precision == "fp16":
+        dtype = torch.float16
+    else:
+        raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
         DiT_models[args.model](
             input_size=img_size,
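Before this hunk, the ternary silently mapped anything other than "fp16" (including typos) to torch.bfloat16; the explicit chain makes a bad value fail loudly. A small sketch of the behavioral difference:

import torch

flag = "pf16"  # a typo

# Old behavior: the typo silently falls back to bfloat16.
old_dtype = torch.float16 if flag == "fp16" else torch.bfloat16

# New behavior: the same typo raises immediately.
if flag == "bf16":
    new_dtype = torch.bfloat16
elif flag == "fp16":
    new_dtype = torch.float16
else:
    raise ValueError(f"Unknown mixed precision {flag}")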
@@ -196,11 +203,15 @@ def main(args):

             # Log loss values:
             all_reduce_mean(loss)
-            if coordinator.is_master() and (step + 1) % args.log_every == 0:
-                pbar.set_postfix({"loss": loss.item()})
-                writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
+            global_step = epoch * num_steps_per_epoch + step
+            pbar.set_postfix({"loss": loss.item(), "step": step, "global_step": global_step})
+
+            # Log to tensorboard
+            if coordinator.is_master() and (global_step + 1) % args.log_every == 0:
+                writer.add_scalar("loss", loss.item(), global_step)

-            if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
+            # Save checkpoint
+            if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0:
                 logger.info(f"Saving checkpoint")
                 save(
                     booster,
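Why the switch from step to global_step matters: step resets to 0 every epoch, so unless num_steps_per_epoch happens to be a multiple of the cadence, (step + 1) % args.log_every drifts at epoch boundaries, and the tensorboard x-axis restarts each epoch. global_step = epoch * num_steps_per_epoch + step counts monotonically across the whole run. A worked example with illustrative numbers:

num_steps_per_epoch = 97  # illustrative; deliberately not a multiple of 10
log_every = 10

for epoch in range(2):
    for step in range(num_steps_per_epoch):
        global_step = epoch * num_steps_per_epoch + step
        # The global_step-keyed cadence fires every 10 steps overall
        # (global_step 9, 19, ..., 189), even across the epoch boundary.
        # A step-keyed cadence would pause between step 89 of epoch 0 and
        # step 9 of epoch 1: a 17-step stretch with no log at all.
        if (global_step + 1) % log_every == 0:
            print(f"epoch={epoch} step={step} global_step={global_step}")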
@@ -210,12 +221,15 @@ def main(args):
                     lr_scheduler,
                     epoch,
                     step + 1,
+                    global_step + 1,
                     args.batch_size,
                     coordinator,
                     experiment_dir,
                     ema_shape_dict,
                 )
-                logger.info(f"Saved checkpoint at epoch {epoch} step {step + 1} to {experiment_dir}")
+                logger.info(
+                    f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {experiment_dir}"
+                )

         # the continuing epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(0)
@@ -242,7 +256,7 @@ def main(args):
     parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--global-seed", type=int, default=42)
     parser.add_argument("--num-workers", type=int, default=4)
-    parser.add_argument("--log-every", type=int, default=50)
+    parser.add_argument("--log-every", type=int, default=10)
     parser.add_argument("--ckpt-every", type=int, default=1000)
     parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
