54
54
)
55
55
from diffusers .loaders import StableDiffusionLoraLoaderMixin
56
56
from diffusers .optimization import get_scheduler
57
- from diffusers .training_utils import _set_state_dict_into_text_encoder , cast_training_params
57
+ from diffusers .training_utils import (
58
+ _set_state_dict_into_text_encoder ,
59
+ cast_training_params ,
60
+ free_memory ,
61
+ )
58
62
from diffusers .utils import (
59
63
check_min_version ,
60
64
convert_state_dict_to_diffusers ,
@@ -151,14 +155,14 @@ def log_validation(
151
155
if args .validation_images is None :
152
156
images = []
153
157
for _ in range (args .num_validation_images ):
154
- with torch .cuda . amp .autocast ():
158
+ with torch .amp .autocast (accelerator . device . type ):
155
159
image = pipeline (** pipeline_args , generator = generator ).images [0 ]
156
160
images .append (image )
157
161
else :
158
162
images = []
159
163
for image in args .validation_images :
160
164
image = Image .open (image )
161
- with torch .cuda . amp .autocast ():
165
+ with torch .amp .autocast (accelerator . device . type ):
162
166
image = pipeline (** pipeline_args , image = image , generator = generator ).images [0 ]
163
167
images .append (image )
164
168
@@ -177,7 +181,7 @@ def log_validation(
177
181
)
178
182
179
183
del pipeline
180
- torch . cuda . empty_cache ()
184
+ free_memory ()
181
185
182
186
return images
183
187
@@ -793,7 +797,7 @@ def main(args):
793
797
cur_class_images = len (list (class_images_dir .iterdir ()))
794
798
795
799
if cur_class_images < args .num_class_images :
796
- torch_dtype = torch .float16 if accelerator .device .type == "cuda" else torch .float32
800
+ torch_dtype = torch .float16 if accelerator .device .type in ( "cuda" , "xpu" ) else torch .float32
797
801
if args .prior_generation_precision == "fp32" :
798
802
torch_dtype = torch .float32
799
803
elif args .prior_generation_precision == "fp16" :
@@ -829,8 +833,7 @@ def main(args):
829
833
image .save (image_filename )
830
834
831
835
del pipeline
832
- if torch .cuda .is_available ():
833
- torch .cuda .empty_cache ()
836
+ free_memory ()
834
837
835
838
# Handle the repository creation
836
839
if accelerator .is_main_process :
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
1085
1088
tokenizer = None
1086
1089
1087
1090
gc .collect ()
1088
- torch . cuda . empty_cache ()
1091
+ free_memory ()
1089
1092
else :
1090
1093
pre_computed_encoder_hidden_states = None
1091
1094
validation_prompt_encoder_hidden_states = None
0 commit comments