@@ -1,7 +1,6 @@
 import argparse
 import os
 import random
-
 import time
 from pathlib import Path

@@ -29,11 +28,12 @@
 from diffusers.utils import is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

+
 if is_wandb_available():
     pass

-PROFILE_DIR = os.environ.get('PROFILE_DIR', None)
-CACHE_DIR = os.environ.get('CACHE_DIR', None)
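+# Optional env vars: PROFILE_DIR is where xp.trace_detached writes profiler traces;
+# CACHE_DIR, if set, enables the persistent XLA compilation cache across runs.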
+PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
+CACHE_DIR = os.environ.get("CACHE_DIR", None)
 if CACHE_DIR:
     xr.initialize_cache(CACHE_DIR, readonly=False)
 xr.use_spmd()
@@ -151,12 +151,24 @@ def start_training(self):
                 dataloader_exception = True
                 print(e)
                 break
-            if step == measure_start_step and PROFILE_DIR is not None:
+            if step == measure_start_step and PROFILE_DIR is not None:
                 xm.wait_device_ops()
-                xp.trace_detached('localhost:9012', PROFILE_DIR, duration_ms=args.profile_duration)
-                last_time = time.time()
+                xp.trace_detached("localhost:9012", PROFILE_DIR, duration_ms=args.profile_duration)
+                last_time = time.time()
             loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
             self.global_step += 1
+
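+            # Register the print through a step closure: the closure runs when the step's
+            # XLA graph actually executes, so reading the loss does not force an early sync.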
+            def print_loss_closure(step, loss):
+                print(f"Step: {step}, Loss: {loss}")
+
+            if args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
             xm.mark_step()
         if not dataloader_exception:
             xm.wait_device_ops()
@@ -170,7 +182,7 @@ def step_fn(
         self,
         pixel_values,
         input_ids,
-    ):
+    ):
         with xp.Trace("model.forward"):
             self.optimizer.zero_grad()
             latents = self.vae.encode(pixel_values).latent_dist.sample()
@@ -196,12 +208,8 @@ def step_fn(
             elif self.noise_scheduler.config.prediction_type == "v_prediction":
                 target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
             else:
-                raise ValueError(
-                    f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
-                )
-            model_pred = self.unet(
-                noisy_latents, timesteps, encoder_hidden_states, return_dict=False
-            )[0]
+                raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
+            model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
         with xp.Trace("model.backward"):
             if self.args.snr_gamma is None:
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -210,9 +218,9 @@ def step_fn(
                 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                 # This is discussed in Section 4.2 of the same paper.
                 snr = compute_snr(self.noise_scheduler, timesteps)
-                mse_loss_weights = torch.stack(
-                    [snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1
-                ).min(dim=1)[0]
+                mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                    dim=1
+                )[0]
                 if self.noise_scheduler.config.prediction_type == "epsilon":
                     mse_loss_weights = mse_loss_weights / snr
                 elif self.noise_scheduler.config.prediction_type == "v_prediction":
@@ -226,11 +234,10 @@ def step_fn(
             self.run_optimizer()
         return loss

+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument(
-        "--profile_duration", type=int, default=10000, help="Profile duration in ms"
-    )
+    parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -359,25 +366,19 @@ def parse_args():
359366 "--loader_prefetch_size" ,
360367 type = int ,
361368 default = 1 ,
362- help = (
363- "Number of subprocesses to use for data loading to cpu."
364- ),
369+ help = ("Number of subprocesses to use for data loading to cpu." ),
365370 )
366371 parser .add_argument (
367372 "--loader_prefetch_factor" ,
368373 type = int ,
369374 default = 2 ,
370- help = (
371- "Number of batches loaded in advance by each worker."
372- ),
375+ help = ("Number of batches loaded in advance by each worker." ),
373376 )
374377 parser .add_argument (
375378 "--device_prefetch_size" ,
376379 type = int ,
377380 default = 1 ,
378- help = (
379- "Number of subprocesses to use for data loading to tpu from cpu. "
380- ),
381+ help = ("Number of subprocesses to use for data loading to tpu from cpu. " ),
381382 )
382383 parser .add_argument ("--adam_beta1" , type = float , default = 0.9 , help = "The beta1 parameter for the Adam optimizer." )
383384 parser .add_argument ("--adam_beta2" , type = float , default = 0.999 , help = "The beta2 parameter for the Adam optimizer." )
@@ -394,10 +395,7 @@ def parse_args():
         type=str,
         default=None,
         choices=["no", "bf16"],
-        help=(
-            "Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"
-        ),
-
+        help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
     )
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -407,6 +405,12 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
+    parser.add_argument(
+        "--print_loss",
+        default=False,
+        action="store_true",
+        help=("Print loss at every step."),
+    )

     args = parser.parse_args()

@@ -416,6 +420,7 @@ def parse_args():

     return args

+
 def setup_optimizer(unet, args):
     optimizer_cls = torch.optim.AdamW
     return optimizer_cls(
@@ -427,6 +432,7 @@ def setup_optimizer(unet, args):
         foreach=True,
     )

+
 def load_dataset(args):
     if args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
@@ -446,6 +452,7 @@ def load_dataset(args):
     )
     return dataset

+
 def get_column_names(dataset, args):
     column_names = dataset["train"].column_names

@@ -470,13 +477,12 @@ def get_column_names(dataset, args):


 def main(args):
-
     args = parse_args()

-    server = xp.start_server(9012)
+    _ = xp.start_server(9012)

     num_devices = xr.global_runtime_device_count()
-    mesh = xs.get_1d_mesh('data')
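+    # A 1-D SPMD mesh over all devices, with the single axis named "data"; it is used
+    # below to shard the input batch for data parallelism.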
+    mesh = xs.get_1d_mesh("data")
     xs.set_global_mesh(mesh)

     text_encoder = CLIPTextModel.from_pretrained(
@@ -511,6 +517,7 @@ def main(args):
     )

     from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
+
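+    # The patch swaps nn.Linear's forward for an einsum-based version so that SPMD
+    # sharding annotations propagate through the linear layers.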
     unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)

     vae.requires_grad_(False)
@@ -562,19 +569,9 @@ def tokenize_captions(examples, is_train=True):

     train_transforms = transforms.Compose(
         [
-            transforms.Resize(
-                args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
-            ),
-            (
-                transforms.CenterCrop(args.resolution)
-                if args.center_crop
-                else transforms.RandomCrop(args.resolution)
-            ),
-            (
-                transforms.RandomHorizontalFlip()
-                if args.random_flip
-                else transforms.Lambda(lambda x: x)
-            ),
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            (transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)),
+            (transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -592,17 +589,13 @@ def preprocess_train(examples):

     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(
-            weight_dtype
-        )
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype)
         input_ids = torch.stack([example["input_ids"] for example in examples])
         return {"pixel_values": pixel_values, "input_ids": input_ids}

     g = torch.Generator()
     g.manual_seed(xr.host_index())
-    sampler = torch.utils.data.RandomSampler(
-        train_dataset, replacement=True, num_samples=int(1e10), generator=g
-    )
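+    # Sampling with replacement and a practically unbounded num_samples makes the loader
+    # never exhaust; the run length is governed by --max_train_steps, not dataset epochs.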
+    sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         sampler=sampler,
@@ -616,9 +609,7 @@ def collate_fn(examples):
         train_dataloader,
         device,
         input_sharding={
-            "pixel_values": xs.ShardingSpec(
-                mesh, ("data", None, None, None), minibatch=True
-            ),
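+            # Shard only the batch axis over the "data" mesh; minibatch=True means each host
+            # loads just its local share of the global batch.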
612+ "pixel_values" : xs .ShardingSpec (mesh , ("data" , None , None , None ), minibatch = True ),
622613 "input_ids" : xs .ShardingSpec (mesh , ("data" , None ), minibatch = True ),
623614 },
624615 loader_prefetch_size = args .loader_prefetch_size ,
@@ -635,15 +626,17 @@ def collate_fn(examples):
     )
     print(f" Total optimization steps = {args.max_train_steps}")

-    trainer = TrainSD(vae=vae,
-                      weight_dtype=weight_dtype,
-                      device=device,
-                      noise_scheduler=noise_scheduler,
-                      unet=unet,
-                      optimizer=optimizer,
-                      text_encoder=text_encoder,
-                      dataloader=train_dataloader,
-                      args=args)
+    trainer = TrainSD(
+        vae=vae,
+        weight_dtype=weight_dtype,
+        device=device,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        optimizer=optimizer,
+        text_encoder=text_encoder,
+        dataloader=train_dataloader,
+        args=args,
+    )

     trainer.start_training()
     unet = trainer.unet.to("cpu")
@@ -672,4 +665,4 @@ def collate_fn(examples):

 if __name__ == "__main__":
     args = parse_args()
-    main(args)
\ No newline at end of file
+    main(args)