@@ -185,9 +185,10 @@ def log_validation(
185185 autocast_ctx = torch .autocast (accelerator .device .type )
186186
187187 # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
188- prompt_embeds , pooled_prompt_embeds , text_ids = pipeline .encode_prompt (
189- pipeline_args ["prompt" ], prompt_2 = pipeline_args ["prompt" ]
190- )
188+ with torch .no_grad ():
189+ prompt_embeds , pooled_prompt_embeds , text_ids = pipeline .encode_prompt (
190+ pipeline_args ["prompt" ], prompt_2 = pipeline_args ["prompt" ]
191+ )
191192 images = []
192193 for _ in range (args .num_validation_images ):
193194 with autocast_ctx :
@@ -940,7 +941,7 @@ def _encode_prompt_with_t5(
940941
941942 prompt_embeds = text_encoder (text_input_ids .to (device ))[0 ]
942943
943- dtype = text_encoder .dtype
944+ dtype = unwrap_model ( text_encoder ) .dtype
944945 prompt_embeds = prompt_embeds .to (dtype = dtype , device = device )
945946
946947 _ , seq_len , _ = prompt_embeds .shape
@@ -983,7 +984,7 @@ def _encode_prompt_with_clip(
983984
984985 # Use pooled output of CLIPTextModel
985986 prompt_embeds = prompt_embeds .pooler_output
986- prompt_embeds = prompt_embeds .to (dtype = text_encoder .dtype , device = device )
987+ prompt_embeds = prompt_embeds .to (dtype = unwrap_model ( text_encoder ) .dtype , device = device )
987988
988989 # duplicate text embeddings for each generation per prompt, using mps friendly method
989990 prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
@@ -1002,7 +1003,7 @@ def encode_prompt(
10021003 text_input_ids_list = None ,
10031004):
10041005 prompt = [prompt ] if isinstance (prompt , str ) else prompt
1005- dtype = text_encoders [0 ].dtype
1006+ dtype = unwrap_model ( text_encoders [0 ]) .dtype
10061007
10071008 pooled_prompt_embeds = _encode_prompt_with_clip (
10081009 text_encoder = text_encoders [0 ],
0 commit comments