 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 if is_wandb_available():
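This import backs the `unwrap_model` helper added below: `torch.compile()` does not modify a module in place but wraps it in a dynamo `OptimizedModule`, and `is_compiled_module` detects that wrapper. A rough sketch of the check (assuming PyTorch 2.x; the real helper in `diffusers.utils.torch_utils` also guards against older torch versions):

```python
import torch

def is_compiled_module_sketch(module) -> bool:
    # torch.compile() returns a torch._dynamo OptimizedModule that keeps
    # the original module reachable via its `_orig_mod` attribute.
    if not hasattr(torch, "_dynamo"):
        return False  # pre-2.0 PyTorch: nothing can be compiled
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
```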
@@ -129,15 +130,12 @@ def log_validation(
     if vae is not None:
         pipeline_args["vae"] = vae
 
-    if text_encoder is not None:
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        unet=accelerator.unwrap_model(unet),
+        unet=unet,
         revision=args.revision,
         variant=args.variant,
         torch_dtype=weight_dtype,
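With this hunk, `log_validation` no longer unwraps models itself; callers now pass in already-unwrapped modules (see the call-site change further down), so the function simply forwards `text_encoder` and `unet` to the pipeline.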
@@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]
@@ -931,11 +930,16 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
     )
 
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for model in models:
-                sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
+                sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
                 model.save_pretrained(os.path.join(output_dir, sub_dir))
 
                 # make sure to pop weight so that corresponding model is not saved again
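`unwrap_model` peels off both wrappers that training can add: `accelerator.unwrap_model` removes the distributed (e.g. DDP) wrapper, and `_orig_mod` removes the `torch.compile` wrapper. This matters in the hooks, where `isinstance(model, type(unwrap_model(unet)))` must compare against the real model class rather than `OptimizedModule`. A small demonstration of the compile wrapper (assumes PyTorch 2.x):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
compiled = torch.compile(net)

print(type(compiled).__name__)                    # OptimizedModule, not Linear
print(compiled._orig_mod is net)                  # True: original module kept intact
print(isinstance(compiled, nn.Linear))            # False
print(isinstance(compiled._orig_mod, nn.Linear))  # True
```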
@@ -946,7 +950,7 @@ def load_model_hook(models, input_dir):
             # pop models so that they are not loaded again
             model = models.pop()
 
-            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            if isinstance(model, type(unwrap_model(text_encoder))):
                 # load transformers style into model
                 load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                 model.config = load_model.config
@@ -991,15 +995,12 @@ def load_model_hook(models, input_dir):
991995 " doing mixed precision training. copy of the weights should still be float32."
992996 )
993997
994- if accelerator .unwrap_model (unet ).dtype != torch .float32 :
995- raise ValueError (
996- f"Unet loaded as datatype { accelerator .unwrap_model (unet ).dtype } . { low_precision_error_string } "
997- )
998+ if unwrap_model (unet ).dtype != torch .float32 :
999+ raise ValueError (f"Unet loaded as datatype { unwrap_model (unet ).dtype } . { low_precision_error_string } " )
9981000
999- if args .train_text_encoder and accelerator . unwrap_model (text_encoder ).dtype != torch .float32 :
1001+ if args .train_text_encoder and unwrap_model (text_encoder ).dtype != torch .float32 :
10001002 raise ValueError (
1001- f"Text encoder loaded as datatype { accelerator .unwrap_model (text_encoder ).dtype } ."
1002- f" { low_precision_error_string } "
1003+ f"Text encoder loaded as datatype { unwrap_model (text_encoder ).dtype } ." f" { low_precision_error_string } "
10031004 )
10041005
10051006 # Enable TF32 for faster training on Ampere GPUs,
@@ -1246,7 +1247,7 @@ def compute_text_embeddings(prompt):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )
 
-                if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
 
                 if args.class_labels_conditioning == "timesteps":
@@ -1256,8 +1257,8 @@ def compute_text_embeddings(prompt):
 
                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+                )[0]
 
                 if model_pred.shape[1] == 6:
                     model_pred, _ = torch.chunk(model_pred, 2, dim=1)
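The UNet call gets the same tuple-return treatment: by default `UNet2DConditionModel` returns a `UNet2DConditionOutput` whose `.sample` field holds the prediction, while `return_dict=False` yields a one-element tuple, so indexing with `[0]` replaces the `.sample` attribute access. A self-contained check on a tiny, randomly initialized UNet (config values are illustrative only):

```python
import torch
from diffusers import UNet2DConditionModel

# a minimal UNet, just large enough to run; not the training configuration
unet = UNet2DConditionModel(
    sample_size=32, in_channels=4, out_channels=4, layers_per_block=1,
    block_out_channels=(32, 64), cross_attention_dim=32,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
)

sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)

with torch.no_grad():
    via_dataclass = unet(sample, timestep, encoder_hidden_states).sample
    via_tuple = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]
assert torch.equal(via_dataclass, via_tuple)
```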
@@ -1350,9 +1351,9 @@ def compute_text_embeddings(prompt):
 
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         images = log_validation(
-                            text_encoder,
+                            unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
                             tokenizer,
-                            unet,
+                            unwrap_model(unet),
                             vae,
                             args,
                             accelerator,
@@ -1375,14 +1376,14 @@ def compute_text_embeddings(prompt):
         pipeline_args = {}
 
         if text_encoder is not None:
-            pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+            pipeline_args["text_encoder"] = unwrap_model(text_encoder)
 
         if args.skip_save_text_encoder:
             pipeline_args["text_encoder"] = None
 
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=accelerator.unwrap_model(unet),
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
             **pipeline_args,