From 0381932cee443219da34e5831fef124aaadcfc28 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 7 Mar 2024 12:27:49 +0530 Subject: [PATCH] fix: prior preservation setting in DreamBooth LoRA SDXL script. --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0d7876db9509..6e920d1a228c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -877,6 +877,8 @@ def collate_fn(examples, with_prior_preservation=False): if with_prior_preservation: pixel_values += [example["class_images"] for example in examples] prompts += [example["class_prompt"] for example in examples] + original_sizes += [example["original_size"] for example in examples] + crop_top_lefts += [example["crop_top_left"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()