@@ -177,16 +177,24 @@ def log_validation(
177177        f"Running validation... \n  Generating { args .num_validation_images }  
178178        f" { args .validation_prompt }  
179179    )
180-     pipeline  =  pipeline .to (accelerator .device )
180+     pipeline  =  pipeline .to (accelerator .device ,  dtype = torch_dtype )
181181    pipeline .set_progress_bar_config (disable = True )
182182
183183    # run inference 
184184    generator  =  torch .Generator (device = accelerator .device ).manual_seed (args .seed ) if  args .seed  is  not None  else  None 
185-     # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() 
186-     autocast_ctx  =  nullcontext ()
185+     autocast_ctx  =  torch .autocast (accelerator .device .type )
187186
188-     with  autocast_ctx :
189-         images  =  [pipeline (** pipeline_args , generator = generator ).images [0 ] for  _  in  range (args .num_validation_images )]
187+     prompt_embeds , pooled_prompt_embeds , text_ids  =  pipeline .encode_prompt (
188+         pipeline_args ["prompt" ], prompt_2 = pipeline_args ["prompt" ]
189+     )
190+     images  =  []
191+     for  _  in  range (args .num_validation_images ):
192+         with  autocast_ctx :
193+             image  =  pipeline (
194+                     prompt_embeds = prompt_embeds ,
195+                     pooled_prompt_embeds = pooled_prompt_embeds ,
196+                     generator = generator ).images [0 ]
197+             images .append (image )
190198
191199    for  tracker  in  accelerator .trackers :
192200        phase_name  =  "test"  if  is_final_validation  else  "validation" 
@@ -203,8 +211,7 @@ def log_validation(
203211            )
204212
205213    del  pipeline 
206-     if  torch .cuda .is_available ():
207-         torch .cuda .empty_cache ()
214+     free_memory ()
208215
209216    return  images 
210217
0 commit comments