
Commit 012d08b

Enable dreambooth lora finetune example on other devices (#10602)
* enable dreambooth_lora on other devices
  Signed-off-by: jiqing-feng <[email protected]>
* enable xpu
  Signed-off-by: jiqing-feng <[email protected]>
* check cuda device before empty cache
  Signed-off-by: jiqing-feng <[email protected]>
* fix comment
  Signed-off-by: jiqing-feng <[email protected]>
* import free_memory
  Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
1 parent 4ace7d0 commit 012d08b

File tree

2 files changed, +13 -8 lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 11 additions & 8 deletions
@@ -54,7 +54,11 @@
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.training_utils import (
+    _set_state_dict_into_text_encoder,
+    cast_training_params,
+    free_memory,
+)
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
     if args.validation_images is None:
         images = []
         for _ in range(args.num_validation_images):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, generator=generator).images[0]
                 images.append(image)
     else:
         images = []
         for image in args.validation_images:
             image = Image.open(image)
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
             images.append(image)
 
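The two autocast changes above swap the CUDA-only torch.cuda.amp.autocast() context manager for the device-agnostic torch.amp.autocast(device_type), which takes the device-type string reported by the Accelerate device ("cuda", "xpu", "cpu", ...). A minimal standalone sketch of the pattern; the device-selection logic and dummy tensors below are illustrative and not part of this commit:

import torch

# Pick whichever accelerator is present; the hasattr guard mirrors the XPU
# check used elsewhere in this commit.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")

x = torch.randn(8, 8, device=device)
w = torch.randn(8, 8, device=device)

# torch.amp.autocast takes the device type as a string, so the same code path
# runs mixed precision on CUDA, XPU, or CPU without a backend-specific API.
with torch.amp.autocast(device.type):
    y = x @ w

print(y.dtype)  # reduced-precision dtype on backends where autocast applies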

@@ -177,7 +181,7 @@ def log_validation(
            )
 
    del pipeline
-    torch.cuda.empty_cache()
+    free_memory()
 
    return images
 
@@ -793,7 +797,7 @@ def main(args):
        cur_class_images = len(list(class_images_dir.iterdir()))
 
        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
            if args.prior_generation_precision == "fp32":
                torch_dtype = torch.float32
            elif args.prior_generation_precision == "fp16":
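
The dtype choice for prior-class image generation now treats XPU like CUDA: half precision by default on those accelerators, full precision elsewhere, unless --prior_generation_precision overrides it. A rough sketch of that selection logic as a standalone helper; the function name and the bf16 branch are illustrative, not taken verbatim from this hunk:

import torch

def prior_generation_dtype(device, precision=None):
    # Default to fp16 on cuda/xpu accelerators, fp32 otherwise (cpu, mps, ...).
    torch_dtype = torch.float16 if device.type in ("cuda", "xpu") else torch.float32
    # An explicit precision choice overrides the device-based default.
    if precision == "fp32":
        torch_dtype = torch.float32
    elif precision == "fp16":
        torch_dtype = torch.float16
    elif precision == "bf16":
        torch_dtype = torch.bfloat16
    return torch_dtype

# Example: prior_generation_dtype(torch.device("xpu")) -> torch.float16
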
@@ -829,8 +833,7 @@ def main(args):
                    image.save(image_filename)
 
            del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()
 
    # Handle the repository creation
    if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
        tokenizer = None
 
        gc.collect()
-        torch.cuda.empty_cache()
+        free_memory()
    else:
        pre_computed_encoder_hidden_states = None
        validation_prompt_encoder_hidden_states = None

src/diffusers/training_utils.py

Lines changed: 2 additions & 0 deletions
@@ -299,6 +299,8 @@ def free_memory():
        torch.mps.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()
 
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
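
The hunk above only shows the tail of free_memory(): the helper runs garbage collection and then clears the cache of whichever accelerator backend is available, and this commit adds the XPU branch. A rough, self-contained sketch of the pattern (the NPU branch visible in the hunk's context is omitted here because it needs the optional torch_npu package):

import gc

import torch

def free_memory():
    # Run Python garbage collection first, then clear the cache of whichever
    # accelerator backend is present. The hasattr guard keeps the XPU check
    # safe on PyTorch builds that do not expose torch.xpu.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()

With this in place the training script can call free_memory() unconditionally, instead of guarding torch.cuda.empty_cache() behind torch.cuda.is_available() as the old train_dreambooth_lora.py code did.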
