@@ -941,7 +941,10 @@ def _encode_prompt_with_t5(
 
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     _, seq_len, _ = prompt_embeds.shape
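Context for the change above (annotation, not part of the diff): once `accelerator.prepare(...)` wraps the text encoder in DistributedDataParallel, attributes such as `dtype` are only reachable through the inner `.module`, which is what the `hasattr(text_encoder, "module")` branch handles. A minimal, generic sketch of the same idea follows; the helper name `module_dtype` is hypothetical and not part of the script:

```python
import torch


def module_dtype(model: torch.nn.Module) -> torch.dtype:
    # Wrappers such as torch.nn.parallel.DistributedDataParallel keep the
    # real model under `.module`; plain modules are used directly.
    inner = model.module if hasattr(model, "module") else model
    # Read the dtype off the parameters, which works for any nn.Module.
    return next(inner.parameters()).dtype
```

With such a helper, the repeated `if`/`else` in these hunks could collapse to a single `dtype = module_dtype(text_encoder)` call.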
@@ -982,9 +985,13 @@ def _encode_prompt_with_clip(
 
     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1003,7 +1010,11 @@ def encode_prompt(
     text_input_ids_list=None,
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
 
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
@@ -1628,7 +1639,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         if args.train_text_encoder:
             text_encoder_one.train()
             # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+            unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
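The bare `unwrap_model(...)` that replaces `accelerator.unwrap_model(...)` in this and the following hunks is a helper defined earlier in the script and not shown in this diff. A plausible sketch, assuming it follows the pattern used in other diffusers training scripts (first strip Accelerate's distributed wrapper, then a possible `torch.compile` wrapper); treat the body as an assumption rather than the actual definition:

```python
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()  # in the script this instance already exists in scope


def unwrap_model(model):
    # Strip Accelerate's distributed wrapper (e.g. DDP) added by accelerator.prepare()...
    model = accelerator.unwrap_model(model)
    # ...then the torch.compile wrapper, which keeps the original module under `_orig_mod`.
    model = model._orig_mod if is_compiled_module(model) else model
    return model
```

Routing every unwrap through one helper keeps the DDP and `torch.compile` cases handled consistently across the training loop and the validation pipeline setup below.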
@@ -1719,7 +1730,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if accelerator.unwrap_model(transformer).config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1837,9 +1848,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 pipeline = FluxPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    transformer=accelerator.unwrap_model(transformer),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    transformer=unwrap_model(transformer),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,