examples/dreambooth/train_dreambooth_lora.py: 20 changes (13 additions, 7 deletions)
@@ -75,6 +75,13 @@
 logger = get_logger(__name__)
 
 
+def free_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()
+
+
 def save_model_card(
     repo_id: str,
     images=None,
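
Note: the new `free_memory()` helper clears the allocator cache on whichever backend is present; `torch.xpu` only exists on PyTorch builds with Intel GPU support, hence the `hasattr` guard. A minimal sketch of the usage pattern the rest of the diff applies (the `pipeline` stand-in here is illustrative, not the script's real object):

```python
import gc

import torch


def free_memory():
    # Return cached allocator blocks on whichever accelerator is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()


# Pattern used throughout the diff: drop the last reference to the large
# object first, collect Python-level garbage, then empty the device cache.
pipeline = object()  # stands in for a real DiffusionPipeline
del pipeline
gc.collect()
free_memory()
```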
@@ -151,14 +158,14 @@ def log_validation(
     if args.validation_images is None:
         images = []
         for _ in range(args.num_validation_images):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, generator=generator).images[0]
             images.append(image)
     else:
         images = []
         for image in args.validation_images:
             image = Image.open(image)
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast(accelerator.device.type):
                 image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
             images.append(image)

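Note: `torch.amp.autocast` dispatches on a device-type string ("cuda", "xpu", "cpu"), so passing `accelerator.device.type` makes the same call work on any backend Accelerate resolves. A minimal sketch of the change (the tiny model and input are illustrative):

```python
import torch

# Resolve a device type string roughly the way Accelerate would report it.
if torch.cuda.is_available():
    device_type = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device_type = "xpu"
else:
    device_type = "cpu"

model = torch.nn.Linear(8, 8).to(device_type)
x = torch.randn(4, 8, device=device_type)

# One device-agnostic call replaces the CUDA-only torch.cuda.amp.autocast().
with torch.amp.autocast(device_type):
    y = model(x)
```
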
@@ -177,7 +184,7 @@ def log_validation(
             )
 
     del pipeline
-    torch.cuda.empty_cache()
+    free_memory()
 
     return images

@@ -793,7 +800,7 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            torch_dtype = torch.float16 if accelerator.device.type in ("cuda", "xpu") else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
             elif args.prior_generation_precision == "fp16":
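
Note: prior-image generation now defaults to fp16 on XPU as well as CUDA, with an explicit `--prior_generation_precision` flag still taking precedence. A hypothetical standalone version of the selection logic, for clarity (the `bf16` branch is assumed from the script's usual precision choices, which the hunk truncates):

```python
import torch


def prior_dtype(device_type, precision=None):
    # fp16 by default on cuda/xpu, fp32 elsewhere; an explicit flag wins.
    dtype = torch.float16 if device_type in ("cuda", "xpu") else torch.float32
    if precision == "fp32":
        dtype = torch.float32
    elif precision == "fp16":
        dtype = torch.float16
    elif precision == "bf16":  # assumed option, mirroring the other precision flags
        dtype = torch.bfloat16
    return dtype


assert prior_dtype("xpu") is torch.float16
assert prior_dtype("cpu") is torch.float32
```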
Expand Down Expand Up @@ -829,8 +836,7 @@ def main(args):
image.save(image_filename)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
@@ -1085,7 +1091,7 @@ def compute_text_embeddings(prompt):
         tokenizer = None
 
         gc.collect()
-        torch.cuda.empty_cache()
+        free_memory()
     else:
         pre_computed_encoder_hidden_states = None
         validation_prompt_encoder_hidden_states = None