examples/dreambooth/train_dreambooth_lora.py (11 additions, 8 deletions)

@@ -54,7 +54,11 @@
 )
 from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.training_utils import (
+    _set_state_dict_into_text_encoder,
+    cast_training_params,
+    free_memory,
+)
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,
@@ -151,14 +155,14 @@ def log_validation(
     if args.validation_images is None:
         images = []
         for _ in range(args.num_validation_images):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, generator=generator).images[0]
                 images.append(image)
     else:
         images = []
         for image in args.validation_images:
             image = Image.open(image)
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                 images.append(image)
 
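Note: the two changed lines above replace the CUDA-only torch.cuda.amp.autocast() with the device-agnostic torch.amp.autocast(device_type), so validation runs under autocast on CUDA, XPU, and CPU alike. A minimal sketch of the pattern, assuming an accelerate.Accelerator and an illustrative checkpoint id:

import torch
from accelerate import Accelerator
from diffusers import DiffusionPipeline

accelerator = Accelerator()
# Illustrative checkpoint, used only to demonstrate the autocast call.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to(accelerator.device)

# torch.amp.autocast takes the device type ("cuda", "xpu", "cpu", ...) as its
# first positional argument, so the same context manager works on any backend.
with torch.amp.autocast(accelerator.device.type):
    image = pipeline("a photo of sks dog").images[0]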
@@ -177,7 +181,7 @@ def log_validation(
             )
 
     del pipeline
-    torch.cuda.empty_cache()
+    free_memory()
 
     return images
 
@@ -793,7 +797,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
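Note: prior-image generation now defaults to float16 on both CUDA and XPU devices and float32 elsewhere, with the --prior_generation_precision flag still taking precedence. A sketch of that precedence as a standalone helper (the function name is hypothetical; the bf16 branch mirrors the script's remaining elif):

import torch
from accelerate import Accelerator

def prior_generation_dtype(accelerator: Accelerator, prior_generation_precision=None):
    # Default: half precision on accelerators that support it, full precision otherwise.
    torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
    # An explicit --prior_generation_precision value overrides the device-based default.
    if prior_generation_precision == "fp32":
        torch_dtype = torch.float32
    elif prior_generation_precision == "fp16":
        torch_dtype = torch.float16
    elif prior_generation_precision == "bf16":
        torch_dtype = torch.bfloat16
    return torch_dtype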
@@ -829,8 +833,7 @@ def main(args):
                     image.save(image_filename)
 
             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1085,7 +1088,7 @@ def compute_text_embeddings(prompt):
         tokenizer = None
 
         gc.collect()
-        torch.cuda.empty_cache()
+        free_memory()
     else:
         pre_computed_encoder_hidden_states = None
         validation_prompt_encoder_hidden_states = None
src/diffusers/training_utils.py (2 additions, 0 deletions)

@@ -299,6 +299,8 @@ def free_memory():
         torch.mps.empty_cache()
     elif is_torch_npu_available():
         torch_npu.npu.empty_cache()
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()
 
 
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
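Note: the new elif branch clears the Intel XPU cache when one is present, guarded by hasattr(torch, "xpu") so older PyTorch builds without XPU support are unaffected. A rough sketch of what the helper does after this change (see diffusers.training_utils.free_memory for the actual implementation, which also covers Ascend NPU via torch_npu):

import gc
import torch

def free_memory():
    # Run Python garbage collection first so unreferenced tensors are released,
    # then clear the cache of whichever accelerator backend is available.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()

Training scripts can then call del pipeline followed by free_memory() instead of guarding torch.cuda.empty_cache() behind torch.cuda.is_available(), which is exactly what the first file in this diff does.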