@@ -225,9 +225,10 @@ def __init__(
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=self.vae.config.latent_channels,
+            vae_latent_channels=latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
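
Note: the new `latent_channels` fallback keeps `__init__` working when the pipeline is constructed without a VAE. A minimal sketch of the pattern, assuming 16 is the intended Flux default (the class below is illustrative, not the actual pipeline):

```python
# Illustrative sketch of the guarded config lookup, not the pipeline itself.
class PipelineSketch:
    def __init__(self, vae=None):
        self.vae = vae
        # getattr guards against construction without a VAE; 16 matches the
        # Flux VAE's default latent channel count.
        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16

print(PipelineSketch().latent_channels)  # 16 (no VAE supplied)
```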
@@ -656,7 +657,7 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
         image,
@@ -670,20 +671,24 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        if latents is not None:
+            return latents.to(device=device, dtype=dtype), latent_image_ids

-        # if latents is not None:
         image = image.to(device=device, dtype=dtype)
         image_latents = self._encode_vae_image(image=image, generator=generator)
-
-        latent_image_ids = self._prepare_latent_image_ids(
-            batch_size, height // 2, width // 2, device, dtype
-        )
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
             additional_image_per_prompt = batch_size // image_latents.shape[0]
@@ -695,19 +700,10 @@ def prepare_latents(
         else:
             image_latents = torch.cat([image_latents], dim=0)

-        if latents is None:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
-        else:
-            noise = latents.to(device)
-            latents = noise
-
-        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
-        image_latents = self._pack_latents(
-            image_latents, batch_size, num_channels_latents, height, width
-        )
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-        return latents, noise, image_latents, latent_image_ids
+        return latents, latent_image_ids

     @property
     def guidance_scale(self):
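
Taken together, `prepare_latents` now follows the img2img flow it is copied from: validate the generator list, round the latent grid so it packs into 2x2 patches, early-return any user-supplied latents, and otherwise noise the encoded image to the strength timestep. A condensed sketch of that last step, assuming the usual Flux `vae_scale_factor` of 8 (helper and argument names mirror the diff):

```python
import torch

def prepare_latents_sketch(scheduler, image_latents, timestep, batch_size,
                           num_channels_latents, height, width,
                           generator=None, device="cpu", dtype=torch.float32,
                           vae_scale_factor=8):
    # The VAE compresses 8x, and packing additionally needs even latent dims.
    h = 2 * (int(height) // (vae_scale_factor * 2))
    w = 2 * (int(width) // (vae_scale_factor * 2))
    # Fresh noise, blended with the encoded image at the strength timestep
    # (scale_noise on the flow-match scheduler does the interpolation).
    noise = torch.randn((batch_size, num_channels_latents, h, w),
                        generator=generator, device=device, dtype=dtype)
    return scheduler.scale_noise(image_latents, timestep, noise)
```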
@@ -866,7 +862,6 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

-        original_image = image
         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)

@@ -935,7 +930,7 @@ def __call__(

         # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
-        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+        latents, latent_image_ids = self.prepare_latents(
             init_image,
             latent_timestep,
             batch_size * num_images_per_prompt,
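
For reference, the 2x2 packing that the rounded `height`/`width` serve: `_pack_latents` in the Flux pipelines reshapes `(B, C, H, W)` latents into a `(B, H/2 * W/2, C * 4)` token sequence. A standalone sketch of that transform (written as a free function for illustration):

```python
import torch

def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, H/2 * W/2, C * 4): every 2x2 spatial patch becomes
    # one token, which is why the latent dims must be rounded to multiples of 2.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

# Shape check with the 16 latent channels used above:
print(pack_latents_sketch(torch.randn(1, 16, 64, 64)).shape)  # (1, 1024, 64)
```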