@@ -941,7 +941,7 @@ def _encode_prompt_with_t5(
941941
942942    prompt_embeds  =  text_encoder (text_input_ids .to (device ))[0 ]
943943
944-     dtype  =  unwrap_model ( text_encoder ) .dtype 
944+     dtype  =  text_encoder .dtype 
945945    prompt_embeds  =  prompt_embeds .to (dtype = dtype , device = device )
946946
947947    _ , seq_len , _  =  prompt_embeds .shape 
@@ -984,7 +984,7 @@ def _encode_prompt_with_clip(
984984
985985    # Use pooled output of CLIPTextModel 
986986    prompt_embeds  =  prompt_embeds .pooler_output 
987-     prompt_embeds  =  prompt_embeds .to (dtype = unwrap_model ( text_encoder ) .dtype , device = device )
987+     prompt_embeds  =  prompt_embeds .to (dtype = text_encoder .dtype , device = device )
988988
989989    # duplicate text embeddings for each generation per prompt, using mps friendly method 
990990    prompt_embeds  =  prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
@@ -1003,7 +1003,7 @@ def encode_prompt(
10031003    text_input_ids_list = None ,
10041004):
10051005    prompt  =  [prompt ] if  isinstance (prompt , str ) else  prompt 
1006-     dtype  =  unwrap_model ( text_encoders [0 ]) .dtype 
1006+     dtype  =  text_encoders [0 ].dtype 
10071007
10081008    pooled_prompt_embeds  =  _encode_prompt_with_clip (
10091009        text_encoder = text_encoders [0 ],