@@ -207,7 +207,8 @@ def __init__(
207207 transformer = transformer ,
208208 scheduler = scheduler ,
209209 )
210- self .vae_scale_factor = 8
210+ self .vae_scale_factor = 2 ** (len (self .vae .config .block_out_channels ) - 1 ) if getattr (self , "vae" , None ) else 8
211+ self .latent_channels = self .vae .config .latent_channels if getattr (self , "vae" , None ) else 16
211212 self .default_sample_size = (
212213 self .transformer .config .sample_size
213214 if hasattr (self , "transformer" ) and self .transformer is not None
@@ -530,7 +531,7 @@ def prepare_latents(
530531
531532 if image is not None :
532533 image = image .to (device = device , dtype = dtype )
533- if image .shape [1 ] != self .transformer . config . in_channels :
534+ if image .shape [1 ] != self .latent_channels :
534535 image_latents = self ._encode_vae_image (image = image , generator = generator )
535536 else :
536537 image_latents = image
@@ -743,8 +744,7 @@ def __call__(
743744 system_prompt = system_prompt ,
744745 )
745746
746- latent_channels = self .transformer .config .in_channels
747- if image is not None and not (isinstance (image , torch .Tensor ) and image .size (1 ) == latent_channels ):
747+ if image is not None and not (isinstance (image , torch .Tensor ) and image .size (1 ) == self .latent_channels ):
748748 img = image [0 ] if isinstance (image , list ) else image
749749 image_height , image_width = self .image_processor .get_default_height_width (img )
750750 image_width = image_width // multiple_of * multiple_of
@@ -753,10 +753,11 @@ def __call__(
753753 image = self .image_processor .preprocess (image , image_height , image_width )
754754
755755 # 4. Prepare latents.
756+ num_channels_latents = self .transformer .config .in_channels
756757 latents , image_latents = self .prepare_latents (
757758 image ,
758759 batch_size * num_images_per_prompt ,
759- latent_channels ,
760+ num_channels_latents ,
760761 height ,
761762 width ,
762763 prompt_embeds .dtype ,
0 commit comments