 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.training_utils import (
+    _set_state_dict_into_text_encoder,
+    cast_training_params,
+    free_memory,
+)
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
     if args.validation_images is None:
         images = []
         for _ in range(args.num_validation_images):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, generator=generator).images[0]
                 images.append(image)
     else:
         images = []
         for image in args.validation_images:
             image = Image.open(image)
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
             images.append(image)

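The `torch.amp.autocast` context above is the device-agnostic replacement for `torch.cuda.amp.autocast`; it takes the device type string as its first argument. A minimal sketch of the pattern, assuming `accelerator` is an `accelerate.Accelerator` (the tensors and ops here are illustrative, not taken from the script):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Passing accelerator.device.type ("cuda", "xpu", "cpu", ...) lets the same
# code pick the right autocast backend instead of hard-coding CUDA.
with torch.amp.autocast(accelerator.device.type):
    x = torch.randn(8, 8, device=accelerator.device)
    y = x @ x  # runs in the backend's autocast dtype where supported
```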
@@ -177,7 +181,7 @@ def log_validation(
             )

     del pipeline
-    torch.cuda.empty_cache()
+    free_memory()

     return images

@@ -793,7 +797,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))

         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
@@ -829,8 +833,7 @@ def main(args):
                     image.save(image_filename)

             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
         tokenizer = None

         gc.collect()
-        torch.cuda.empty_cache()
+        free_memory()
     else:
         pre_computed_encoder_hidden_states = None
         validation_prompt_encoder_hidden_states = None
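`free_memory()` from `diffusers.training_utils` replaces the scattered `gc.collect()` / `torch.cuda.empty_cache()` calls with a single backend-aware cleanup. As a rough, hypothetical sketch of the kind of cleanup such a helper performs (not the library's exact implementation):

```python
import gc
import torch

def free_memory_sketch():
    # Illustrative only: collect Python garbage, then flush whichever
    # accelerator cache is present (CUDA, XPU, or MPS).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
```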