@@ -233,9 +233,11 @@ def __init__(
233233 self .vae_scale_factor = (
234234 2 ** (len (self .vae .config .block_out_channels ) - 1 ) if hasattr (self , "vae" ) and self .vae is not None else 8
235235 )
236- self .image_processor = VaeImageProcessor (vae_scale_factor = self .vae_scale_factor )
236+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
237+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
238+ self .image_processor = VaeImageProcessor (vae_scale_factor = self .vae_scale_factor * 2 )
237239 self .mask_processor = VaeImageProcessor (
238- vae_scale_factor = self .vae_scale_factor ,
240+ vae_scale_factor = self .vae_scale_factor * 2 ,
239241 vae_latent_channels = self .vae .config .latent_channels ,
240242 do_normalize = False ,
241243 do_binarize = True ,
@@ -467,9 +469,9 @@ def check_inputs(
467469 if strength < 0 or strength > 1 :
468470 raise ValueError (f"The value of strength should in [0.0, 1.0] but is { strength } " )
469471
470- if height % self .vae_scale_factor != 0 or width % self .vae_scale_factor != 0 :
471- raise ValueError (
472- f"`height` and `width` have to be divisible by { self .vae_scale_factor } but are { height } and { width } ."
472+ if height % ( self .vae_scale_factor * 2 ) != 0 or width % ( self .vae_scale_factor * 2 ) != 0 :
473+ logger . warning (
474+ f"`height` and `width` have to be divisible by { self .vae_scale_factor * 2 } but are { height } and { width } . Dimensions will be resized accordingly "
473475 )
474476
475477 if callback_on_step_end_tensor_inputs is not None and not all (
@@ -578,8 +580,10 @@ def prepare_latents(
578580 f" size of { batch_size } . Make sure the batch size matches the length of the generators."
579581 )
580582
581- height = int (height ) // self .vae_scale_factor
582- width = int (width ) // self .vae_scale_factor
583+ # VAE applies 8x compression on images but we must also account for packing which requires
584+ # latent height and width to be divisible by 2.
585+ height = int (height ) // self .vae_scale_factor - ((int (height ) // self .vae_scale_factor ) % 2 )
586+ width = int (width ) // self .vae_scale_factor - ((int (width ) // self .vae_scale_factor ) % 2 )
583587
584588 shape = (batch_size , num_channels_latents , height , width )
585589 latent_image_ids = self ._prepare_latent_image_ids (batch_size , height // 2 , width // 2 , device , dtype )
@@ -624,8 +628,10 @@ def prepare_mask_latents(
624628 device ,
625629 generator ,
626630 ):
627- height = int (height ) // self .vae_scale_factor
628- width = int (width ) // self .vae_scale_factor
631+ # VAE applies 8x compression on images but we must also account for packing which requires
632+ # latent height and width to be divisible by 2.
633+ height = int (height ) // self .vae_scale_factor - ((int (height ) // self .vae_scale_factor ) % 2 )
634+ width = int (width ) // self .vae_scale_factor - ((int (width ) // self .vae_scale_factor ) % 2 )
629635 # resize the mask to latents shape as we concatenate the mask to the latents
630636 # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
631637 # and half precision
0 commit comments