 
 import colossalai
 import torch
+import torch.distributed as dist
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
@@ -49,6 +50,7 @@ def main(args):
     model_string_name = args.model.replace("/", "-")
     # Create an experiment folder
     experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
+    dist.barrier()
     if coordinator.is_master():
         os.makedirs(experiment_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
@@ -97,7 +99,12 @@ def main(args):
 
     # Create model
     img_size = dataset[0][0].shape[-1]
-    dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+    if args.mixed_precision == "bf16":
+        dtype = torch.bfloat16
+    elif args.mixed_precision == "fp16":
+        dtype = torch.float16
+    else:
+        raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
         DiT_models[args.model](
             input_size=img_size,
@@ -196,11 +203,15 @@ def main(args):
 
             # Log loss values:
             all_reduce_mean(loss)
-            if coordinator.is_master() and (step + 1) % args.log_every == 0:
-                pbar.set_postfix({"loss": loss.item()})
-                writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
+            global_step = epoch * num_steps_per_epoch + step
+            pbar.set_postfix({"loss": loss.item(), "step": step, "global_step": global_step})
+
+            # Log to tensorboard
+            if coordinator.is_master() and (global_step + 1) % args.log_every == 0:
+                writer.add_scalar("loss", loss.item(), global_step)
 
-            if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
+            # Save checkpoint
+            if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0:
                 logger.info(f"Saving checkpoint")
                 save(
                     booster,
@@ -210,12 +221,15 @@ def main(args):
                     lr_scheduler,
                     epoch,
                     step + 1,
+                    global_step + 1,
                     args.batch_size,
                     coordinator,
                     experiment_dir,
                     ema_shape_dict,
                 )
-                logger.info(f"Saved checkpoint at epoch {epoch} step {step + 1} to {experiment_dir}")
+                logger.info(
+                    f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {experiment_dir}"
+                )
 
         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(0)
@@ -242,7 +256,7 @@ def main(args):
     parser.add_argument("--batch-size", type=int, default=2)
     parser.add_argument("--global-seed", type=int, default=42)
     parser.add_argument("--num-workers", type=int, default=4)
-    parser.add_argument("--log-every", type=int, default=50)
+    parser.add_argument("--log-every", type=int, default=10)
     parser.add_argument("--ckpt-every", type=int, default=1000)
     parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
     parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
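
A minimal sketch (not part of the commit) of how the new global_step counter keys the logging and checkpoint cadence; num_steps_per_epoch, the loop bounds, and the print below are assumed values standing in for the real writer.add_scalar and save() calls:

    # Illustrative only: global_step advances monotonically across epochs, so
    # --log-every and --ckpt-every now count optimizer steps overall rather
    # than steps within the current epoch.
    num_steps_per_epoch = 1000        # assumed value for the sketch
    log_every, ckpt_every = 10, 1000  # defaults from this commit

    for epoch in range(2):
        for step in range(num_steps_per_epoch):
            global_step = epoch * num_steps_per_epoch + step
            if (global_step + 1) % log_every == 0:
                pass  # writer.add_scalar("loss", loss.item(), global_step) in the real script
            if ckpt_every > 0 and (global_step + 1) % ckpt_every == 0:
                print(f"checkpoint at epoch={epoch} step={step + 1} global_step={global_step + 1}")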