5454)
5555from diffusers .loaders import StableDiffusionLoraLoaderMixin
5656from diffusers .optimization import get_scheduler
57- from diffusers .training_utils import _set_state_dict_into_text_encoder , cast_training_params
57+ from diffusers .training_utils import (
58+ _set_state_dict_into_text_encoder ,
59+ cast_training_params ,
60+ free_memory ,
61+ )
5862from diffusers .utils import (
5963 check_min_version ,
6064 convert_state_dict_to_diffusers ,
@@ -151,14 +155,14 @@ def log_validation(
151155 if args .validation_images is None :
152156 images = []
153157 for _ in range (args .num_validation_images ):
154- with torch .cuda . amp .autocast ():
158+ with torch .amp .autocast (accelerator . device . type ):
155159 image = pipeline (** pipeline_args , generator = generator ).images [0 ]
156160 images .append (image )
157161 else :
158162 images = []
159163 for image in args .validation_images :
160164 image = Image .open (image )
161- with torch .cuda . amp .autocast ():
165+ with torch .amp .autocast (accelerator . device . type ):
162166 image = pipeline (** pipeline_args , image = image , generator = generator ).images [0 ]
163167 images .append (image )
164168
@@ -177,7 +181,7 @@ def log_validation(
177181 )
178182
179183 del pipeline
180- torch . cuda . empty_cache ()
184+ free_memory ()
181185
182186 return images
183187
@@ -793,7 +797,7 @@ def main(args):
793797 cur_class_images = len (list (class_images_dir .iterdir ()))
794798
795799 if cur_class_images < args .num_class_images :
796- torch_dtype = torch .float16 if accelerator .device .type == "cuda" else torch .float32
800+ torch_dtype = torch .float16 if accelerator .device .type in ( "cuda" , "xpu" ) else torch .float32
797801 if args .prior_generation_precision == "fp32" :
798802 torch_dtype = torch .float32
799803 elif args .prior_generation_precision == "fp16" :
@@ -829,8 +833,7 @@ def main(args):
829833 image .save (image_filename )
830834
831835 del pipeline
832- if torch .cuda .is_available ():
833- torch .cuda .empty_cache ()
836+ free_memory ()
834837
835838 # Handle the repository creation
836839 if accelerator .is_main_process :
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
10851088 tokenizer = None
10861089
10871090 gc .collect ()
1088- torch . cuda . empty_cache ()
1091+ free_memory ()
10891092 else :
10901093 pre_computed_encoder_hidden_states = None
10911094 validation_prompt_encoder_hidden_states = None
0 commit comments