@@ -140,33 +140,43 @@ def run_optimizer(self):
         self.optimizer.step()

     def start_training(self):
-        times = []
-        last_time = time.time()
-        step = 0
-        while True:
-            if self.global_step >= self.args.max_train_steps:
-                xm.mark_step()
-                break
-            if step == 4 and PROFILE_DIR is not None:
-                xm.wait_device_ops()
-                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+        dataloader_exception = False
+        measure_start_step = args.measure_start_step
+        assert measure_start_step < self.args.max_train_steps
+        total_time = 0
+        for step in range(0, self.args.max_train_steps):
             try:
                 batch = next(self.dataloader)
             except Exception as e:
+                dataloader_exception = True
                 print(e)
                 break
+            if step == measure_start_step and PROFILE_DIR is not None:
+                xm.wait_device_ops()
+                xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
+                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
-            step_time = time.time() - last_time
-            if step >= 10:
-                times.append(step_time)
-            print(f"step: {step}, step_time: {step_time}")
-            if step % 5 == 0:
-                print(f"step: {step}, loss: {loss}")
-            last_time = time.time()
             self.global_step += 1
-            step += 1
-        # print(f"Average step time: {sum(times)/len(times)}")
-        xm.wait_device_ops()
+
+            def print_loss_closure(step, loss):
+                print(f"Step: {step}, Loss: {loss}")
+
+            if args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
+            xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
+        else:
+            print("dataloader exception happen, skip result")
+        return

     def step_fn(
         self,
@@ -180,7 +190,10 @@ def step_fn(
         noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
         bsz = latents.shape[0]
         timesteps = torch.randint(
-            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+            0,
+            self.noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            device=latents.device,
         )
         timesteps = timesteps.long()

@@ -224,9 +237,6 @@ def step_fn(

 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
-    )
     parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -258,12 +268,6 @@ def parse_args():
258268 " or to a folder containing files that 🤗 Datasets can understand."
259269 ),
260270 )
261- parser .add_argument (
262- "--dataset_config_name" ,
263- type = str ,
264- default = None ,
265- help = "The config of the Dataset, leave as None if there's only one config." ,
266- )
267271 parser .add_argument (
268272 "--train_data_dir" ,
269273 type = str ,
@@ -283,15 +287,6 @@ def parse_args():
283287 default = "text" ,
284288 help = "The column of the dataset containing a caption or a list of captions." ,
285289 )
286- parser .add_argument (
287- "--max_train_samples" ,
288- type = int ,
289- default = None ,
290- help = (
291- "For debugging purposes or quicker training, truncate the number of training examples to this "
292- "value if set."
293- ),
294- )
295290 parser .add_argument (
296291 "--output_dir" ,
297292 type = str ,
@@ -304,7 +299,6 @@ def parse_args():
         default=None,
         help="The directory where the downloaded models and datasets will be stored.",
     )
-    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
         type=int,
@@ -374,12 +368,19 @@ def parse_args():
         default=1,
         help=("Number of subprocesses to use for data loading to cpu."),
     )
+    parser.add_argument(
+        "--loader_prefetch_factor",
+        type=int,
+        default=2,
+        help=("Number of batches loaded in advance by each worker."),
+    )
     parser.add_argument(
         "--device_prefetch_size",
         type=int,
         default=1,
         help=("Number of subprocesses to use for data loading to tpu from cpu. "),
     )
+    parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -394,12 +395,8 @@ def parse_args():
394395 "--mixed_precision" ,
395396 type = str ,
396397 default = None ,
397- choices = ["no" , "fp16" , "bf16" ],
398- help = (
399- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
400- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
401- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
402- ),
398+ choices = ["no" , "bf16" ],
399+ help = ("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10" ),
403400 )
404401 parser .add_argument ("--push_to_hub" , action = "store_true" , help = "Whether or not to push the model to the Hub." )
405402 parser .add_argument ("--hub_token" , type = str , default = None , help = "The token to use to push to the Model Hub." )
@@ -409,6 +406,12 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument(
+        "--print_loss",
+        default=False,
+        action="store_true",
+        help=("Print loss at every step."),
+    )

     args = parser.parse_args()

@@ -436,7 +439,6 @@ def load_dataset(args):
     # Downloading and loading a dataset from the hub.
     dataset = datasets.load_dataset(
         args.dataset_name,
-        args.dataset_config_name,
         cache_dir=args.cache_dir,
         data_dir=args.train_data_dir,
     )
@@ -481,9 +483,7 @@ def main(args):
     _ = xp.start_server(PORT)

     num_devices = xr.global_runtime_device_count()
-    device_ids = np.arange(num_devices)
-    mesh_shape = (num_devices, 1)
-    mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
+    mesh = xs.get_1d_mesh("data")
     xs.set_global_mesh(mesh)

     text_encoder = CLIPTextModel.from_pretrained(
@@ -520,6 +520,7 @@ def main(args):
     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear

     unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
+    unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -530,15 +531,12 @@ def main(args):
     # as these weights are only used for inference, keeping weights in full
     # precision is not required.
     weight_dtype = torch.float32
-    if args.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif args.mixed_precision == "bf16":
+    if args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16

     device = xm.xla_device()
-    print("device: ", device)
-    print("weight_dtype: ", weight_dtype)

+    # Move text_encode and vae to device and cast to weight_dtype
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=weight_dtype)
@@ -606,24 +604,27 @@ def collate_fn(examples):
         collate_fn=collate_fn,
         num_workers=args.dataloader_num_workers,
         batch_size=args.train_batch_size,
+        prefetch_factor=args.loader_prefetch_factor,
     )

     train_dataloader = pl.MpDeviceLoader(
         train_dataloader,
         device,
         input_sharding={
-            "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
-            "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
+            "pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
+            "input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
         },
         loader_prefetch_size=args.loader_prefetch_size,
         device_prefetch_size=args.device_prefetch_size,
     )

+    num_hosts = xr.process_count()
+    num_devices_per_host = num_devices // num_hosts
     if xm.is_master_ordinal():
         print("***** Running training *****")
-        print(f"Instantaneous batch size per device = {args.train_batch_size}")
+        print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
         print(
-            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
+            f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
         )
         print(f" Total optimization steps = {args.max_train_steps}")
