@@ -224,11 +224,11 @@ def __init__(
224224        self .vae_scale_factor  =  2  **  (len (self .vae .config .block_out_channels ) -  1 ) if  getattr (self , "vae" , None ) else  8 
225225        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
226226        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
227-         self .image_processor  =  VaeImageProcessor ( vae_scale_factor = self .vae_scale_factor   *   2 ) 
228-         latent_channels   =   self .vae . config . latent_channels   if   getattr ( self ,  "vae" ,  None )  else   16 
227+         self .latent_channels  =  self .vae . config . latent_channels   if   getattr ( self ,  "vae" ,  None )  else   16 
228+         self .image_processor   =   VaeImageProcessor ( vae_scale_factor = self . vae_scale_factor   *   2 , vae_latent_channels = self . latent_channels ) 
229229        self .mask_processor  =  VaeImageProcessor (
230230            vae_scale_factor = self .vae_scale_factor  *  2 ,
231-             vae_latent_channels = latent_channels ,
231+             vae_latent_channels = self . latent_channels ,
232232            do_normalize = False ,
233233            do_binarize = True ,
234234            do_convert_grayscale = True ,
@@ -686,7 +686,10 @@ def prepare_latents(
686686            return  latents .to (device = device , dtype = dtype ), latent_image_ids 
687687
688688        image  =  image .to (device = device , dtype = dtype )
689-         image_latents  =  self ._encode_vae_image (image = image , generator = generator )
689+         if  image .shape [1 ] !=  self .latent_channels :
690+             image_latents  =  self ._encode_vae_image (image = image , generator = generator )
691+         else :
692+             image_latents  =  image 
690693        if  batch_size  >  image_latents .shape [0 ] and  batch_size  %  image_latents .shape [0 ] ==  0 :
691694            # expand init_latents for batch_size 
692695            additional_image_per_prompt  =  batch_size  //  image_latents .shape [0 ]
0 commit comments