-
Notifications
You must be signed in to change notification settings - Fork 6.4k
Open
Labels
bugSomething isn't workingSomething isn't working
Description
Describe the bug
When running Wan 2.2 TI2V 5B with tiled VAE, the VAE step is unsuccessful, resulting in:
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
Reproduction
import torch
import numpy as np
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
dtype = torch.bfloat16
device = "cuda"
model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.enable_tiling(
tile_sample_min_height=256,
tile_sample_min_width=256,
tile_sample_stride_height=64,
tile_sample_stride_width=64
)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=dtype)
pipe.to(device)
height = 704
width = 1280
num_frames = 121
num_inference_steps = 50
guidance_scale = 5.0
prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "่ฒ่ฐ่ณไธฝ๏ผ่ฟๆ๏ผ้ๆ๏ผ็ป่ๆจก็ณไธๆธ
๏ผๅญๅน๏ผ้ฃๆ ผ๏ผไฝๅ๏ผ็ปไฝ๏ผ็ป้ข๏ผ้ๆญข๏ผๆดไฝๅ็ฐ๏ผๆๅทฎ่ดจ้๏ผไฝ่ดจ้๏ผJPEGๅ็ผฉๆฎ็๏ผไธ้็๏ผๆฎ็ผบ็๏ผๅคไฝ็ๆๆ๏ผ็ปๅพไธๅฅฝ็ๆ้จ๏ผ็ปๅพไธๅฅฝ็่ธ้จ๏ผ็ธๅฝข็๏ผๆฏๅฎน็๏ผๅฝขๆ็ธๅฝข็่ขไฝ๏ผๆๆ่ๅ๏ผ้ๆญขไธๅจ็็ป้ข๏ผๆไนฑ็่ๆฏ๏ผไธๆก่
ฟ๏ผ่ๆฏไบบๅพๅค๏ผๅ็่ตฐ"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
).frames[0]
export_to_video(output, "5bit2v_output.mp4", fps=24)
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")Logs
Loading checkpoint shards: 100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 3/3 [00:01<00:00, 2.07it/s]
The config attributes {'clip_output': False} were passed to AutoencoderKLWan, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 5/5 [00:28<00:00, 5.60s/it]
Loading pipeline components...: 100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 5/5 [00:00<00:00, 7.77it/s]
100%|โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ| 10/10 [01:26<00:00, 8.66s/it]
Traceback (most recent call last):
File "/workspace/Wan2.2/wan22tiledvae.py", line 75, in <module>
output = pipe(
^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/pipelines/wan/pipeline_wan.py", line 645, in __call__
video = self.vae.decode(latents, return_dict=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1248, in decode
decoded = self._decode(z).sample
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1204, in _decode
return self.tiled_decode(z, return_dict=return_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1374, in tiled_decode
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 893, in forward
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 709, in forward
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2System Info
Tested with multiple, no difference noted.
- ๐ค Diffusers version: 0.35.2
- Platform: Linux-6.8.0-65-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.12.3
- PyTorch version (GPU?): 2.8.0+cu128 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.35.3
- Transformers version: 4.51.3
- Accelerate version: 1.11.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA RTX A5000, 24564 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working