@@ -1196,7 +1196,10 @@ def _encode_prompt_with_t5(
11961196
11971197 prompt_embeds = text_encoder (text_input_ids .to (device ))[0 ]
11981198
1199- dtype = text_encoder .dtype
1199+ if hasattr (text_encoder , "module" ):
1200+ dtype = text_encoder .module .dtype
1201+ else :
1202+ dtype = text_encoder .dtype
12001203 prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
12011204
12021205 _ , seq_len , _ = prompt_embeds .shape
@@ -1237,9 +1240,13 @@ def _encode_prompt_with_clip(
12371240
12381241 prompt_embeds = text_encoder (text_input_ids .to (device ), output_hidden_states = False )
12391242
1243+ if hasattr (text_encoder , "module" ):
1244+ dtype = text_encoder .module .dtype
1245+ else :
1246+ dtype = text_encoder .dtype
12401247 # Use pooled output of CLIPTextModel
12411248 prompt_embeds = prompt_embeds .pooler_output
1242- prompt_embeds = prompt_embeds .to (dtype = text_encoder . dtype , device = device )
1249+ prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
12431250
12441251 # duplicate text embeddings for each generation per prompt, using mps friendly method
12451252 prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
@@ -1258,7 +1265,10 @@ def encode_prompt(
12581265 text_input_ids_list = None ,
12591266):
12601267 prompt = [prompt ] if isinstance (prompt , str ) else prompt
1261- dtype = text_encoders [0 ].dtype
1268+ if hasattr (text_encoders [0 ], "module" ):
1269+ dtype = text_encoders [0 ].module .dtype
1270+ else :
1271+ dtype = text_encoders [0 ].dtype
12621272
12631273 pooled_prompt_embeds = _encode_prompt_with_clip (
12641274 text_encoder = text_encoders [0 ],
@@ -2040,7 +2050,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20402050 if args .train_text_encoder :
20412051 text_encoder_one .train ()
20422052 # set top parameter requires_grad = True for gradient checkpointing works
2043- accelerator . unwrap_model (text_encoder_one ).text_model .embeddings .requires_grad_ (True )
2053+ unwrap_model (text_encoder_one ).text_model .embeddings .requires_grad_ (True )
20442054 elif args .train_text_encoder_ti : # textual inversion / pivotal tuning
20452055 text_encoder_one .train ()
20462056 if args .enable_t5_ti :
@@ -2148,7 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
21482158 )
21492159
21502160 # handle guidance
2151- if accelerator . unwrap_model (transformer ).config .guidance_embeds :
2161+ if unwrap_model (transformer ).config .guidance_embeds :
21522162 guidance = torch .tensor ([args .guidance_scale ], device = accelerator .device )
21532163 guidance = guidance .expand (model_input .shape [0 ])
21542164 else :
@@ -2290,9 +2300,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
22902300 pipeline = FluxPipeline .from_pretrained (
22912301 args .pretrained_model_name_or_path ,
22922302 vae = vae ,
2293- text_encoder = accelerator . unwrap_model (text_encoder_one ),
2294- text_encoder_2 = accelerator . unwrap_model (text_encoder_two ),
2295- transformer = accelerator . unwrap_model (transformer ),
2303+ text_encoder = unwrap_model (text_encoder_one ),
2304+ text_encoder_2 = unwrap_model (text_encoder_two ),
2305+ transformer = unwrap_model (transformer ),
22962306 revision = args .revision ,
22972307 variant = args .variant ,
22982308 torch_dtype = weight_dtype ,
0 commit comments