7575logger = get_logger (__name__ )
7676
7777
78+ def free_memory ():
79+ if torch .cuda .is_available ():
80+ torch .cuda .empty_cache ()
81+ if hasattr (torch , "xpu" ) and torch .xpu .is_available ():
82+ torch .xpu .empty_cache ()
83+
84+
7885def save_model_card (
7986 repo_id : str ,
8087 images = None ,
@@ -151,14 +158,14 @@ def log_validation(
151158 if args .validation_images is None :
152159 images = []
153160 for _ in range (args .num_validation_images ):
154- with torch .amp .autocast (pipeline .device .type ):
161+ with torch .amp .autocast (accelerator .device .type ):
155162 image = pipeline (** pipeline_args , generator = generator ).images [0 ]
156163 images .append (image )
157164 else :
158165 images = []
159166 for image in args .validation_images :
160167 image = Image .open (image )
161- with torch .amp .autocast (pipeline .device .type ):
168+ with torch .amp .autocast (accelerator .device .type ):
162169 image = pipeline (** pipeline_args , image = image , generator = generator ).images [0 ]
163170 images .append (image )
164171
@@ -177,10 +184,7 @@ def log_validation(
177184 )
178185
179186 del pipeline
180- if torch .cuda .is_available ():
181- torch .cuda .empty_cache ()
182- if hasattr (torch , "xpu" ) and torch .xpu .is_available ():
183- torch .xpu .empty_cache ()
187+ free_memory ()
184188
185189 return images
186190
@@ -832,10 +836,7 @@ def main(args):
832836 image .save (image_filename )
833837
834838 del pipeline
835- if torch .cuda .is_available ():
836- torch .cuda .empty_cache ()
837- if hasattr (torch , "xpu" ) and torch .xpu .is_available ():
838- torch .xpu .empty_cache ()
839+ free_memory ()
839840
840841 # Handle the repository creation
841842 if accelerator .is_main_process :
@@ -1090,10 +1091,7 @@ def compute_text_embeddings(prompt):
10901091 tokenizer = None
10911092
10921093 gc .collect ()
1093- if torch .cuda .is_available ():
1094- torch .cuda .empty_cache ()
1095- if hasattr (torch , "xpu" ) and torch .xpu .is_available ():
1096- torch .xpu .empty_cache ()
1094+ free_memory ()
10971095 else :
10981096 pre_computed_encoder_hidden_states = None
10991097 validation_prompt_encoder_hidden_states = None
0 commit comments