@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(

     prompt_embeds = text_encoder(text_input_ids.to(device))[0]

-    dtype = text_encoder.dtype
+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

     _, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(

     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

+    if hasattr(text_encoder, "module"):
+        dtype = text_encoder.module.dtype
+    else:
+        dtype = text_encoder.dtype
     # Use pooled output of CLIPTextModel
     prompt_embeds = prompt_embeds.pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

     # duplicate text embeddings for each generation per prompt, using mps friendly method
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
 ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
-    dtype = text_encoders[0].dtype
+
+    if hasattr(text_encoders[0], "module"):
+        dtype = text_encoders[0].module.dtype
+    else:
+        dtype = text_encoders[0].dtype
+
     device = device if device is not None else text_encoders[1].device
     pooled_prompt_embeds = _encode_prompt_with_clip(
         text_encoder=text_encoders[0],
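
The same unwrapping pattern appears in all three hunks. Below is a minimal sketch of the reasoning, using hypothetical names (`TinyEncoder`, `unwrapped_dtype`) that are not part of this patch: when a text encoder is wrapped for distributed training (for example by `torch.nn.parallel.DistributedDataParallel` via `accelerator.prepare`), the wrapper keeps the original model under `.module` and does not forward custom attributes such as the `dtype` property that transformers models expose, so reading `dtype` directly on the wrapper raises `AttributeError`.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a transformers text encoder exposing a ``dtype`` property."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @property
    def dtype(self) -> torch.dtype:
        return self.proj.weight.dtype


def unwrapped_dtype(model: nn.Module) -> torch.dtype:
    """Return ``model.dtype``, looking through a DDP-style ``.module`` wrapper if present."""
    inner = model.module if hasattr(model, "module") else model
    return inner.dtype


# unwrapped_dtype(TinyEncoder()) -> torch.float32
# After DDP wrapping, ``wrapper.dtype`` would raise AttributeError, since nn.Module attribute
# lookup only searches parameters, buffers, and submodules; ``wrapper.module.dtype`` still
# resolves to the inner model's property.
```

The patch repeats the `hasattr` check inline in each helper rather than factoring it out; a helper like the sketch above could fold the three copies into one, but the inline form keeps each hunk self-contained.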