Skip to content

Commit 660a0bf

Browse files
committed
fix comment
Signed-off-by: jiqing-feng <[email protected]>
1 parent 52924fd commit 660a0bf

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@
7575
logger = get_logger(__name__)
7676

7777

78+
def free_memory():
79+
if torch.cuda.is_available():
80+
torch.cuda.empty_cache()
81+
if hasattr(torch, "xpu") and torch.xpu.is_available():
82+
torch.xpu.empty_cache()
83+
84+
7885
def save_model_card(
7986
repo_id: str,
8087
images=None,
@@ -151,14 +158,14 @@ def log_validation(
151158
if args.validation_images is None:
152159
images = []
153160
for _ in range(args.num_validation_images):
154-
with torch.amp.autocast(pipeline.device.type):
161+
with torch.amp.autocast(accelerator.device.type):
155162
image = pipeline(**pipeline_args, generator=generator).images[0]
156163
images.append(image)
157164
else:
158165
images = []
159166
for image in args.validation_images:
160167
image = Image.open(image)
161-
with torch.amp.autocast(pipeline.device.type):
168+
with torch.amp.autocast(accelerator.device.type):
162169
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
163170
images.append(image)
164171

@@ -177,10 +184,7 @@ def log_validation(
177184
)
178185

179186
del pipeline
180-
if torch.cuda.is_available():
181-
torch.cuda.empty_cache()
182-
if hasattr(torch, "xpu") and torch.xpu.is_available():
183-
torch.xpu.empty_cache()
187+
free_memory()
184188

185189
return images
186190

@@ -832,10 +836,7 @@ def main(args):
832836
image.save(image_filename)
833837

834838
del pipeline
835-
if torch.cuda.is_available():
836-
torch.cuda.empty_cache()
837-
if hasattr(torch, "xpu") and torch.xpu.is_available():
838-
torch.xpu.empty_cache()
839+
free_memory()
839840

840841
# Handle the repository creation
841842
if accelerator.is_main_process:
@@ -1090,10 +1091,7 @@ def compute_text_embeddings(prompt):
10901091
tokenizer = None
10911092

10921093
gc.collect()
1093-
if torch.cuda.is_available():
1094-
torch.cuda.empty_cache()
1095-
if hasattr(torch, "xpu") and torch.xpu.is_available():
1096-
torch.xpu.empty_cache()
1094+
free_memory()
10971095
else:
10981096
pre_computed_encoder_hidden_states = None
10991097
validation_prompt_encoder_hidden_states = None

0 commit comments

Comments (0)