@@ -225,9 +225,10 @@ def __init__(
225225        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible 
226226        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this 
227227        self .image_processor  =  VaeImageProcessor (vae_scale_factor = self .vae_scale_factor  *  2 )
228+         latent_channels  =  self .vae .config .latent_channels  if  getattr (self , "vae" , None ) else  16 
228229        self .mask_processor  =  VaeImageProcessor (
229230            vae_scale_factor = self .vae_scale_factor  *  2 ,
230-             vae_latent_channels = self . vae . config . latent_channels ,
231+             vae_latent_channels = latent_channels ,
231232            do_normalize = False ,
232233            do_binarize = True ,
233234            do_convert_grayscale = True ,
@@ -656,7 +657,7 @@ def disable_vae_tiling(self):
656657        """ 
657658        self .vae .disable_tiling ()
658659
659-     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline .prepare_latents 
660+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxImg2ImgPipeline .prepare_latents 
660661    def  prepare_latents (
661662        self ,
662663        image ,
@@ -670,20 +671,24 @@ def prepare_latents(
670671        generator ,
671672        latents = None ,
672673    ):
674+         if  isinstance (generator , list ) and  len (generator ) !=  batch_size :
675+             raise  ValueError (
676+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
677+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
678+             )
679+ 
673680        # VAE applies 8x compression on images but we must also account for packing which requires 
674681        # latent height and width to be divisible by 2. 
675682        height  =  2  *  (int (height ) //  (self .vae_scale_factor  *  2 ))
676683        width  =  2  *  (int (width ) //  (self .vae_scale_factor  *  2 ))
677- 
678684        shape  =  (batch_size , num_channels_latents , height , width )
685+         latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height  //  2 , width  //  2 , device , dtype )
686+ 
687+         if  latents  is  not None :
688+             return  latents .to (device = device , dtype = dtype ), latent_image_ids 
679689
680-         # if latents is not None: 
681690        image  =  image .to (device = device , dtype = dtype )
682691        image_latents  =  self ._encode_vae_image (image = image , generator = generator )
683- 
684-         latent_image_ids  =  self ._prepare_latent_image_ids (
685-             batch_size , height  //  2 , width  //  2 , device , dtype 
686-         )
687692        if  batch_size  >  image_latents .shape [0 ] and  batch_size  %  image_latents .shape [0 ] ==  0 :
688693            # expand init_latents for batch_size 
689694            additional_image_per_prompt  =  batch_size  //  image_latents .shape [0 ]
@@ -695,19 +700,10 @@ def prepare_latents(
695700        else :
696701            image_latents  =  torch .cat ([image_latents ], dim = 0 )
697702
698-         if  latents  is  None :
699-             noise  =  randn_tensor (shape , generator = generator , device = device , dtype = dtype )
700-             latents  =  self .scheduler .scale_noise (image_latents , timestep , noise )
701-         else :
702-             noise  =  latents .to (device )
703-             latents  =  noise 
704- 
705-         noise  =  self ._pack_latents (noise , batch_size , num_channels_latents , height , width )
706-         image_latents  =  self ._pack_latents (
707-             image_latents , batch_size , num_channels_latents , height , width 
708-         )
703+         noise  =  randn_tensor (shape , generator = generator , device = device , dtype = dtype )
704+         latents  =  self .scheduler .scale_noise (image_latents , timestep , noise )
709705        latents  =  self ._pack_latents (latents , batch_size , num_channels_latents , height , width )
710-         return  latents , noise ,  image_latents ,  latent_image_ids 
706+         return  latents , latent_image_ids 
711707
712708    @property  
713709    def  guidance_scale (self ):
@@ -866,7 +862,6 @@ def __call__(
866862        self ._joint_attention_kwargs  =  joint_attention_kwargs 
867863        self ._interrupt  =  False 
868864
869-         original_image  =  image 
870865        init_image  =  self .image_processor .preprocess (image , height = height , width = width )
871866        init_image  =  init_image .to (dtype = torch .float32 )
872867
@@ -935,7 +930,7 @@ def __call__(
935930
936931        # 5. Prepare latent variables 
937932        num_channels_latents  =  self .vae .config .latent_channels 
938-         latents , noise ,  image_latents ,  latent_image_ids  =  self .prepare_latents (
933+         latents , latent_image_ids  =  self .prepare_latents (
939934            init_image ,
940935            latent_timestep ,
941936            batch_size  *  num_images_per_prompt ,
0 commit comments