 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 if is_wandb_available():
@@ -787,6 +788,12 @@ def main(args):
     logger.info("Initializing controlnet weights from unet")
     controlnet = ControlNetModel.from_unet(unet)

+    # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -846,9 +853,9 @@ def load_model_hook(models, input_dir):
         " doing mixed precision training, copy of the weights should still be float32."
     )

-    if accelerator.unwrap_model(controlnet).dtype != torch.float32:
+    if unwrap_model(controlnet).dtype != torch.float32:
         raise ValueError(
-            f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}"
+            f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
         )

     # Enable TF32 for faster training on Ampere GPUs,
@@ -1015,7 +1022,7 @@ def load_model_hook(models, input_dir):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)

@@ -1036,7 +1043,8 @@ def load_model_hook(models, input_dir):
                         sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                     ],
                     mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
-                ).sample
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if noise_scheduler.config.prediction_type == "epsilon":
@@ -1109,7 +1117,7 @@ def load_model_hook(models, input_dir):
     # Create the pipeline using using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        controlnet = accelerator.unwrap_model(controlnet)
+        controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)

         if args.push_to_hub:
0 commit comments