Skip to content

Commit a343848

Browse files
committed
check cuda device before empty cache
Signed-off-by: jiqing-feng <[email protected]>
1 parent cd45e14 commit a343848

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ def log_validation(
177177
)
178178

179179
del pipeline
180-
torch.cuda.empty_cache()
180+
if torch.cuda.is_available():
181+
torch.cuda.empty_cache()
181182
if hasattr(torch, "xpu") and torch.xpu.is_available():
182183
torch.xpu.empty_cache()
183184

@@ -1089,7 +1090,8 @@ def compute_text_embeddings(prompt):
10891090
tokenizer = None
10901091

10911092
gc.collect()
1092-
torch.cuda.empty_cache()
1093+
if torch.cuda.is_available():
1094+
torch.cuda.empty_cache()
10931095
if hasattr(torch, "xpu") and torch.xpu.is_available():
10941096
torch.xpu.empty_cache()
10951097
else:

0 commit comments

Comments
 (0)