@@ -222,11 +222,13 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
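
The `vae_latent_channels` argument matters here because, as far as I can tell, `VaeImageProcessor.preprocess` returns a tensor untouched when its channel dimension already matches the VAE's latent channel count, so the pipeline can now accept pre-encoded latents as its `image` input instead of only pixel images. A minimal sketch of that assumed behaviour (the real processor also handles PIL/NumPy inputs, resizing, and normalization):

import torch

def preprocess_sketch(image: torch.Tensor, vae_latent_channels: int = 16) -> torch.Tensor:
    # Assumed behaviour: a 4-D tensor whose channel dim equals the VAE's latent
    # channel count is treated as latents and passed through unchanged.
    if image.ndim == 4 and image.shape[1] == vae_latent_channels:
        return image
    # Otherwise treat it as pixel data in [0, 1] and rescale to [-1, 1]
    # (standing in for the full resize/normalize path).
    return 2.0 * image - 1.0

latents = torch.randn(1, 16, 64, 64)
assert preprocess_sketch(latents) is latents  # latents are not re-processed
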
@@ -653,7 +655,10 @@ def prepare_latents(
         latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image

         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
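
With this check, `prepare_latents` skips the VAE encode entirely when `image` already has `self.latent_channels` channels, i.e. when the caller hands in latents rather than pixels. Note that the pass-through branch applies no shift/scale normalization, so presumably the caller is expected to provide latents that are already in the VAE's scaled latent space. A toy illustration of the dispatch, with a stand-in encoder:

import torch

LATENT_CHANNELS = 16  # illustrative; the pipeline reads this from the VAE config

def encode_or_passthrough(image: torch.Tensor, encode_fn) -> torch.Tensor:
    # Pixel input (e.g. 3 channels): run the VAE encoder.
    # Latent input (16 channels): use as-is, avoiding a double encode.
    if image.shape[1] != LATENT_CHANNELS:
        return encode_fn(image)
    return image

# stand-in "encoder" that only mimics the output shape (8x spatial downsampling)
fake_encode = lambda x: torch.zeros(x.shape[0], LATENT_CHANNELS, x.shape[2] // 8, x.shape[3] // 8)

pixels = torch.rand(1, 3, 512, 512)
latents = torch.randn(1, LATENT_CHANNELS, 64, 64)
assert encode_or_passthrough(pixels, fake_encode).shape == (1, 16, 64, 64)
assert encode_or_passthrough(latents, fake_encode) is latents
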
@@ -710,7 +715,9 @@ def prepare_mask_latents(
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

-        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+        masked_image_latents = (
+            masked_image_latents - self.vae.config.shift_factor
+        ) * self.vae.config.scaling_factor

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
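
The re-wrapped expression is the standard VAE latent normalization, scaled = (latents - shift_factor) * scaling_factor, which decoding later inverts. A quick round-trip sketch with illustrative values (the real ones come from `self.vae.config`):

import torch

shift_factor, scaling_factor = 0.1159, 0.3611  # illustrative only

def to_scaled(latents: torch.Tensor) -> torch.Tensor:
    # normalization applied after VAE encoding
    return (latents - shift_factor) * scaling_factor

def from_scaled(scaled: torch.Tensor) -> torch.Tensor:
    # inverse applied before VAE decoding
    return scaled / scaling_factor + shift_factor

x = torch.randn(1, 16, 64, 64)
assert torch.allclose(from_scaled(to_scaled(x)), x, atol=1e-5)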