@@ -222,11 +222,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
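The hunk above caches the VAE's latent channel count on the pipeline and forwards it to `VaeImageProcessor` via `vae_latent_channels`, so the preprocessor can tell pre-encoded latents apart from RGB images by their channel dimension. A minimal sketch of that dispatch idea, using a hypothetical helper rather than the actual `VaeImageProcessor` internals:

```python
import torch

# Hypothetical helper illustrating channel-count dispatch; the real logic lives
# inside diffusers' VaeImageProcessor and differs in detail.
def maybe_preprocess(image: torch.Tensor, latent_channels: int = 16) -> torch.Tensor:
    if image.ndim == 4 and image.shape[1] == latent_channels:
        # Input already looks like VAE latents -> leave it untouched.
        return image
    # Otherwise treat it as a pixel-space image in [0, 1] and normalize to [-1, 1].
    return image * 2.0 - 1.0

pixels = torch.rand(1, 3, 512, 512)    # RGB image  -> gets normalized
latents = torch.randn(1, 16, 64, 64)   # latents    -> passed through unchanged
assert maybe_preprocess(latents).shape == latents.shape
```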
@@ -653,7 +655,10 @@ def prepare_latents(
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image

         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
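With `self.latent_channels` available, `prepare_latents` can now accept either a pixel-space image or latents produced earlier (for example by a previous pipeline call with `output_type="latent"`); only pixel inputs are routed through the VAE encoder. A small standalone sketch of that branch, with `encode_fn` standing in for `self._encode_vae_image`:

```python
import torch

def encode_or_passthrough(image: torch.Tensor, latent_channels: int, encode_fn):
    # Channel dimension differs from the VAE latent count -> this is a pixel image.
    if image.shape[1] != latent_channels:
        return encode_fn(image)
    # Channel dimension matches -> assume pre-encoded latents and reuse them as-is.
    return image

already_latents = torch.randn(2, 16, 64, 64)
out = encode_or_passthrough(already_latents, 16, encode_fn=lambda x: x)
assert out is already_latents  # no second encoding pass
```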
@@ -710,7 +715,9 @@ def prepare_mask_latents(
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

-        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+            masked_image_latents = (
+                masked_image_latents - self.vae.config.shift_factor
+            ) * self.vae.config.scaling_factor

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
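The indentation change above scopes the shift-and-scale normalization to the `else` branch, so it is applied only to latents freshly produced by `self.vae.encode(masked_image)`; a masked image that already arrives as latents is assumed to be normalized and is not rescaled a second time. For reference, a sketch of the normalization and its inverse (`shift_factor` and `scaling_factor` come from the VAE config; the values below are placeholders, not the real config values):

```python
shift_factor = 0.12    # placeholder, read from self.vae.config.shift_factor in practice
scaling_factor = 0.36  # placeholder, read from self.vae.config.scaling_factor in practice

def to_model_latents(z):
    # Raw VAE posterior sample -> scaled latents consumed by the transformer.
    return (z - shift_factor) * scaling_factor

def to_vae_latents(z_scaled):
    # Scaled latents -> raw latents fed back into vae.decode().
    return z_scaled / scaling_factor + shift_factor

assert abs(to_vae_latents(to_model_latents(1.0)) - 1.0) < 1e-9
```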