4646from  diffusers .training_utils  import  EMAModel , compute_snr 
4747from  diffusers .utils  import  check_min_version , deprecate , is_wandb_available , make_image_grid 
4848from  diffusers .utils .import_utils  import  is_xformers_available 
49+ from  diffusers .utils .torch_utils  import  is_compiled_module 
4950
5051
5152if  is_wandb_available ():
@@ -833,6 +834,12 @@ def collate_fn(examples):
833834        tracker_config .pop ("validation_prompts" )
834835        accelerator .init_trackers (args .tracker_project_name , tracker_config )
835836
837+     # Function for unwrapping if model was compiled with `torch.compile`. 
838+     def  unwrap_model (model ):
839+         model  =  accelerator .unwrap_model (model )
840+         model  =  model ._orig_mod  if  is_compiled_module (model ) else  model 
841+         return  model 
842+ 
836843    # Train! 
837844    total_batch_size  =  args .train_batch_size  *  accelerator .num_processes  *  args .gradient_accumulation_steps 
838845
@@ -912,7 +919,7 @@ def collate_fn(examples):
912919                    noisy_latents  =  noise_scheduler .add_noise (latents , noise , timesteps )
913920
914921                # Get the text embedding for conditioning 
915-                 encoder_hidden_states  =  text_encoder (batch ["input_ids" ])[0 ]
922+                 encoder_hidden_states  =  text_encoder (batch ["input_ids" ],  return_dict = False )[0 ]
916923
917924                # Get the target for loss depending on the prediction type 
918925                if  args .prediction_type  is  not   None :
@@ -927,7 +934,7 @@ def collate_fn(examples):
927934                    raise  ValueError (f"Unknown prediction type { noise_scheduler .config .prediction_type }  " )
928935
929936                # Predict the noise residual and compute loss 
930-                 model_pred  =  unet (noisy_latents , timesteps , encoder_hidden_states ). sample 
937+                 model_pred  =  unet (noisy_latents , timesteps , encoder_hidden_states ,  return_dict = False )[ 0 ] 
931938
932939                if  args .snr_gamma  is  None :
933940                    loss  =  F .mse_loss (model_pred .float (), target .float (), reduction = "mean" )
@@ -1023,7 +1030,7 @@ def collate_fn(examples):
10231030    # Create the pipeline using the trained modules and save it. 
10241031    accelerator .wait_for_everyone ()
10251032    if  accelerator .is_main_process :
1026-         unet  =  accelerator . unwrap_model (unet )
1033+         unet  =  unwrap_model (unet )
10271034        if  args .use_ema :
10281035            ema_unet .copy_to (unet .parameters ())
10291036
0 commit comments