diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index 32bce9531b71..e5c092cb16ea 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -741,6 +741,9 @@ def __init__( self.buckets = buckets + # Initialize image processor for condition image preprocessing + self.image_processor = Flux2ImageProcessor() + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset if args.dataset_name is not None: @@ -827,13 +830,13 @@ def __init__( dest_image = self.cond_images[i] image_width, image_height = dest_image.size if image_width * image_height > 1024 * 1024: - dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024) + dest_image = self.image_processor._resize_to_target_area(dest_image, 1024 * 1024) image_width, image_height = dest_image.size multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp! image_width = (image_width // multiple_of) * multiple_of image_height = (image_height // multiple_of) * multiple_of - dest_image = Flux2ImageProcessor.image_processor.preprocess( + dest_image = self.image_processor.preprocess( dest_image, height=image_height, width=image_width, resize_mode="crop" )