@@ -225,7 +225,10 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+        )
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
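Passing `vae_latent_channels` lets the image processor tell pixel-space inputs apart from tensors that are already latents: when an input tensor's channel count matches the VAE's latent channels, `preprocess` returns it untouched instead of resizing and normalizing it like an RGB image. A minimal sketch of that behavior, assuming diffusers' `VaeImageProcessor` and a Flux-style VAE with 16 latent channels and a packed scale factor of 16 (8 * 2):

```python
# Sketch only: illustrates the vae_latent_channels pass-through, not code from this commit.
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=16, vae_latent_channels=16)

rgb = torch.rand(1, 3, 512, 512)       # pixel-space image in [0, 1]
latents = torch.randn(1, 16, 64, 64)   # already-encoded latents

print(processor.preprocess(rgb).shape)      # resized/normalized image: torch.Size([1, 3, 512, 512])
print(processor.preprocess(latents).shape)  # expected to be returned as-is: torch.Size([1, 16, 64, 64])
```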
@@ -634,7 +637,10 @@ def prepare_latents(
             return latents.to(device=device, dtype=dtype), latent_image_ids

         image = image.to(device=device, dtype=dtype)
-        image_latents = self._encode_vae_image(image=image, generator=generator)
+        if image.shape[1] != self.latent_channels:
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+        else:
+            image_latents = image  # input already in latent space, skip the VAE encode
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
             additional_image_per_prompt = batch_size // image_latents.shape[0]
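Together, the two changes let callers pass pre-encoded latents as `image`: `prepare_latents` only routes the input through the VAE when its channel count differs from `self.latent_channels`. A self-contained sketch of that dispatch (the helper name and the fake encoder below are hypothetical, not from this commit):

```python
import torch

LATENT_CHANNELS = 16  # plays the role of self.latent_channels for a Flux-style VAE

def to_image_latents(image: torch.Tensor, encode_vae_image) -> torch.Tensor:
    """Hypothetical stand-in for the branch added in prepare_latents."""
    if image.shape[1] != LATENT_CHANNELS:
        return encode_vae_image(image)  # pixel-space input: encode with the VAE
    return image  # already latents: pass through unchanged

# Fake encoder: maps (B, 3, H, W) pixels to (B, 16, H/8, W/8) latents
fake_encode = lambda img: torch.randn(img.shape[0], LATENT_CHANNELS, img.shape[2] // 8, img.shape[3] // 8)

print(to_image_latents(torch.rand(1, 3, 512, 512), fake_encode).shape)    # encoded -> torch.Size([1, 16, 64, 64])
print(to_image_latents(torch.randn(1, 16, 64, 64), fake_encode).shape)    # passthrough -> torch.Size([1, 16, 64, 64])
```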