Skip to content

Commit 83343de

Browse files
committed
import free_memory
Signed-off-by: jiqing-feng <[email protected]>
1 parent 432700b commit 83343de

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@
5454
)
5555
from diffusers.loaders import StableDiffusionLoraLoaderMixin
5656
from diffusers.optimization import get_scheduler
57-
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
57+
from diffusers.training_utils import (
58+
_set_state_dict_into_text_encoder,
59+
cast_training_params,
60+
free_memory,
61+
)
5862
from diffusers.utils import (
5963
check_min_version,
6064
convert_state_dict_to_diffusers,
@@ -75,13 +79,6 @@
7579
logger = get_logger(__name__)
7680

7781

78-
def free_memory():
79-
if torch.cuda.is_available():
80-
torch.cuda.empty_cache()
81-
if hasattr(torch, "xpu") and torch.xpu.is_available():
82-
torch.xpu.empty_cache()
83-
84-
8582
def save_model_card(
8683
repo_id: str,
8784
images=None,

src/diffusers/training_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ def free_memory():
299299
torch.mps.empty_cache()
300300
elif is_torch_npu_available():
301301
torch_npu.npu.empty_cache()
302+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
303+
torch.xpu.empty_cache()
302304

303305

304306
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14

0 commit comments

Comments (0)