
Commit e3ccdb8

fix: correct init vae_scale_factor and add latent_channels
1 parent 2bf6677 commit e3ccdb8

File tree

1 file changed, +6 -5 lines changed

src/diffusers/pipelines/lumina2/pipeline_lumina2_accessory.py

Lines changed: 6 additions & 5 deletions
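For context, a minimal sketch of the arithmetic behind the new defaults (config values assumed, not taken from the commit): a diffusers-style AutoencoderKL downsamples by 2 for every encoder block after the first, so a config with four block_out_channels entries yields a vae_scale_factor of 8, and latent_channels is now read from the VAE config rather than from the transformer.

# Sketch with assumed config values, mirroring the __init__ change below.
vae_config = {
    "block_out_channels": [128, 256, 512, 512],  # 4 blocks -> 3 downsampling stages
    "latent_channels": 16,
}
vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1)  # 2 ** 3 == 8
latent_channels = vae_config["latent_channels"]                      # 16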
@@ -207,7 +207,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = 8
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         self.default_sample_size = (
             self.transformer.config.sample_size
             if hasattr(self, "transformer") and self.transformer is not None
@@ -530,7 +531,7 @@ def prepare_latents(
 
         if image is not None:
             image = image.to(device=device, dtype=dtype)
-            if image.shape[1] != self.transformer.config.in_channels:
+            if image.shape[1] != self.latent_channels:
                 image_latents = self._encode_vae_image(image=image, generator=generator)
             else:
                 image_latents = image
@@ -743,8 +744,7 @@ def __call__(
             system_prompt=system_prompt,
         )
 
-        latent_channels = self.transformer.config.in_channels
-        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == latent_channels):
+        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
             img = image[0] if isinstance(image, list) else image
             image_height, image_width = self.image_processor.get_default_height_width(img)
             image_width = image_width // multiple_of * multiple_of
@@ -753,10 +753,11 @@ def __call__(
             image = self.image_processor.preprocess(image, image_height, image_width)
 
         # 4. Prepare latents.
+        num_channels_latents = self.transformer.config.in_channels
         latents, image_latents = self.prepare_latents(
             image,
             batch_size * num_images_per_prompt,
-            latent_channels,
+            num_channels_latents,
             height,
             width,
             prompt_embeds.dtype,
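A rough sketch of the check the prepare_latents and __call__ hunks rely on (helper name and shapes assumed, not part of the commit): an incoming image tensor counts as pre-encoded latents only when its channel dimension matches self.latent_channels; otherwise it is routed through the VAE encoder.

import torch

# Simplified sketch of the latent-vs-pixel branch used in prepare_latents.
def maybe_encode(image: torch.Tensor, latent_channels: int, encode_fn):
    if image.shape[1] == latent_channels:
        return image            # already latents, e.g. (B, 16, H // 8, W // 8)
    return encode_fn(image)     # pixel-space input, e.g. (B, 3, H, W)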
