 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -647,6 +648,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]
 
@@ -843,6 +845,11 @@ def main(args):
         )
         text_encoder.add_adapter(text_lora_config)
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
@@ -852,9 +859,9 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                elif isinstance(model, type(unwrap_model(text_encoder))):
                     text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
@@ -877,9 +884,9 @@ def load_model_hook(models, input_dir):
         while len(models) > 0:
             model = models.pop()
 
-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            elif isinstance(model, type(unwrap_model(text_encoder))):
                 text_encoder_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1118,7 +1125,7 @@ def compute_text_embeddings(prompt):
                         text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                     )
 
-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
 
                 if args.class_labels_conditioning == "timesteps":
@@ -1128,8 +1135,12 @@ def compute_text_embeddings(prompt):
 
                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    encoder_hidden_states,
+                    class_labels=class_labels,
+                    return_dict=False,
+                )[0]
 
                 # if model predicts variance, throw away the prediction. we will only train on the
                 # simplified training objective. This means that all schedulers using the fine tuned
@@ -1215,8 +1226,8 @@ def compute_text_embeddings(prompt):
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=accelerator.unwrap_model(unet),
-                    text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
+                    unet=unwrap_model(unet),
+                    text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -1284,13 +1295,13 @@ def compute_text_embeddings(prompt):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet = unet.to(torch.float32)
 
         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
        if args.train_text_encoder:
-            text_encoder = accelerator.unwrap_model(text_encoder)
+            text_encoder = unwrap_model(text_encoder)
             text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
         else:
             text_encoder_state_dict = None
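
Note on the new unwrap_model helper: torch.compile wraps a module in an OptimizedModule and keeps the original module on its _orig_mod attribute, so an isinstance check against the original class fails on the compiled wrapper; the helper strips both the Accelerate wrapper and this compile wrapper before the type checks in the save/load hooks above. A minimal standalone sketch of that behavior (PyTorch 2.x; not part of the training script, and the TinyNet module is made up purely for illustration):

import torch

class TinyNet(torch.nn.Module):
    def forward(self, x):
        return x * 2

net = TinyNet()
compiled = torch.compile(net)  # wraps net in an OptimizedModule; no compilation runs until forward()

print(isinstance(compiled, TinyNet))            # False: the wrapper is not a TinyNet
print(isinstance(compiled._orig_mod, TinyNet))  # True: the original module lives on _orig_mod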