Skip to content

Wan2.2 TI2V-5B Tiled VAE Tensor size mismatchย #12529

@CalamitousFelicitousness

Description

Describe the bug

When running Wan 2.2 TI2V 5B with tiled VAE, the VAE step is unsuccessful, resulting in:
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2

Reproduction

import torch
import numpy as np
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

dtype = torch.bfloat16
device = "cuda"

model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_sample_stride_height=64,
    tile_sample_stride_width=64
)

pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=dtype)
pipe.to(device)

height = 704
width = 1280
num_frames = 121
num_inference_steps = 50
guidance_scale = 5.0

prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "่‰ฒ่ฐƒ่‰ณไธฝ๏ผŒ่ฟ‡ๆ›๏ผŒ้™ๆ€๏ผŒ็ป†่Š‚ๆจก็ณŠไธๆธ…๏ผŒๅญ—ๅน•๏ผŒ้ฃŽๆ ผ๏ผŒไฝœๅ“๏ผŒ็”ปไฝœ๏ผŒ็”ป้ข๏ผŒ้™ๆญข๏ผŒๆ•ดไฝ“ๅ‘็ฐ๏ผŒๆœ€ๅทฎ่ดจ้‡๏ผŒไฝŽ่ดจ้‡๏ผŒJPEGๅŽ‹็ผฉๆฎ‹็•™๏ผŒไธ‘้™‹็š„๏ผŒๆฎ‹็ผบ็š„๏ผŒๅคšไฝ™็š„ๆ‰‹ๆŒ‡๏ผŒ็”ปๅพ—ไธๅฅฝ็š„ๆ‰‹้ƒจ๏ผŒ็”ปๅพ—ไธๅฅฝ็š„่„ธ้ƒจ๏ผŒ็•ธๅฝข็š„๏ผŒๆฏๅฎน็š„๏ผŒๅฝขๆ€็•ธๅฝข็š„่‚ขไฝ“๏ผŒๆ‰‹ๆŒ‡่žๅˆ๏ผŒ้™ๆญขไธๅŠจ็š„็”ป้ข๏ผŒๆ‚ไนฑ็š„่ƒŒๆ™ฏ๏ผŒไธ‰ๆก่…ฟ๏ผŒ่ƒŒๆ™ฏไบบๅพˆๅคš๏ผŒๅ€’็€่ตฐ"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
).frames[0]

export_to_video(output, "5bit2v_output.mp4", fps=24)

print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")

Logs

Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 3/3 [00:01<00:00,  2.07it/s]
The config attributes {'clip_output': False} were passed to AutoencoderKLWan, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5/5 [00:28<00:00,  5.60s/it]
Loading pipeline components...: 100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 5/5 [00:00<00:00,  7.77it/s]
100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 10/10 [01:26<00:00,  8.66s/it]
Traceback (most recent call last):
  File "/workspace/Wan2.2/wan22tiledvae.py", line 75, in <module>
    output = pipe(
             ^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/wan/pipeline_wan.py", line 645, in __call__
    video = self.vae.decode(latents, return_dict=False)[0]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1248, in decode
    decoded = self._decode(z).sample
              ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1204, in _decode
    return self.tiled_decode(z, return_dict=return_dict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1374, in tiled_decode
    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 893, in forward
    x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 709, in forward
    x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
        ~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2

System Info

Tested with multiple, no difference noted.

  • ๐Ÿค— Diffusers version: 0.35.2
  • Platform: Linux-6.8.0-65-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.35.3
  • Transformers version: 4.51.3
  • Accelerate version: 1.11.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX A5000, 24564 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul @DN6

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions