diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0d7876db9509..6e920d1a228c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -877,6 +877,8 @@ def collate_fn(examples, with_prior_preservation=False): if with_prior_preservation: pixel_values += [example["class_images"] for example in examples] prompts += [example["class_prompt"] for example in examples] + original_sizes += [example["original_size"] for example in examples] + crop_top_lefts += [example["crop_top_left"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()