
@oneflyingfish

What does this PR do?

When I call enable_tiling() on autoencoder_kl_wan.AutoencoderKLWan, inference fails with a compute error. I have identified the cause, noticed other potential implementation issues while fixing it, and included fixes for them in this PR. Running the reproduction code below prints:

Multiple distributions found for package optimum. Picked distribution: optimum
The config attributes {'clip_output': False} were passed to AutoencoderKLWan, but are not expected and will be ignored. Please verify your config.json configuration file.
output shape: torch.Size([1, 3, 81, 736, 1280])
fail to inference vae with vae.enable_tiling() Given groups=1, weight of size [160, 12, 3, 3, 3], expected input[1, 3, 3, 258, 258] to have 12 channels, but got 3 channels instead
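The channel mismatch in the last line can be read from the shapes: the first convolution's weight is [160, 12, 3, 3, 3], so it expects 12 input channels, but the tiled path passes it a raw 3-channel RGB tile. Since 12 = 3 × 2 × 2, this is consistent with a spatial patchify step (patch size 2, folding each 2 × 2 pixel block into the channel axis) that the tiled path skips. That reading is an inference from the error message, not something the PR states. A minimal NumPy sketch of such a space-to-channel rearrangement follows; `patchify` and `unpatchify` are hypothetical names, and the exact channel ordering used inside diffusers may differ.

```python
import numpy as np


def patchify(x: np.ndarray, p: int) -> np.ndarray:
    """Fold each p x p spatial patch into the channel axis (hypothetical helper).

    (B, C, F, H, W) -> (B, C * p * p, F, H // p, W // p)
    """
    b, c, f, h, w = x.shape
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size")
    x = x.reshape(b, c, f, h // p, p, w // p, p)
    # Axes are now (b, c, f, h', ph, w', pw); move ph and pw next to c.
    x = x.transpose(0, 1, 4, 6, 2, 3, 5)
    return x.reshape(b, c * p * p, f, h // p, w // p)


def unpatchify(x: np.ndarray, p: int) -> np.ndarray:
    """Inverse of patchify: (B, C * p * p, F, H', W') -> (B, C, F, H' * p, W' * p)."""
    b, cpp, f, hp, wp = x.shape
    c = cpp // (p * p)
    x = x.reshape(b, c, p, p, f, hp, wp)
    # Axes are (b, c, ph, pw, f, h', w'); restore (b, c, f, h', ph, w', pw).
    x = x.transpose(0, 1, 4, 5, 2, 6, 3)
    return x.reshape(b, c, f, hp * p, wp * p)
```

Applied to a single-frame 256 × 256 RGB tile of shape (1, 3, 1, 256, 256), `patchify(tile, 2)` returns shape (1, 12, 1, 128, 128), i.e. the 12 channels the first convolution expects.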

Bug reproduction code:

from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
import torch

dtype = torch.bfloat16
device = torch.device("cuda:0")

weight = "/path/to/vae"
vae = (
    AutoencoderKLWan.from_pretrained(
        weight,
        torch_dtype=dtype,  # weights are already loaded in bfloat16
    )
    .eval()
    .to(device)  # put the weights on the same device as the input
)

with torch.no_grad():
    torch.manual_seed(0)
    dummy_input = (torch.randn((1, 3, 81, 736, 1280), device=device, dtype=dtype) - 0.5) / 0.5  # B, C, F, H, W

    torch.manual_seed(0)
    # encode without tiling
    latent = vae.encode(dummy_input).latent_dist.mode()  # type: torch.Tensor

    # decode without tiling
    gen_raw = vae.decode(latent, return_dict=False)[0]
    print(f"output shape: {gen_raw.shape}")

    # run the same round trip with tiling enabled
    try:
        torch.manual_seed(0)
        vae.enable_tiling()
        # encode
        latent = vae.encode(dummy_input).latent_dist.mode()  # type: torch.Tensor

        # decode
        tile_gen = vae.decode(latent, return_dict=False)[0]
    except Exception as ex:
        print("fail to inference vae with vae.enable_tiling(), error info", ex)

Verification

Verified with the VCmpTool comparison tool.

Comparison:

Original video:

10120244.mp4

Video generated with tiling enabled:

decode_gen_0_tiling.mp4
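The comparison above is visual. As a hypothetical complement that is not part of this PR, the tiled and non-tiled decoder outputs from the reproduction script (`gen_raw` and `tile_gen`) could also be compared numerically. The minimal PSNR sketch below assumes decoded pixels lie in [-1, 1], hence a default data range of 2.0; `psnr` is an illustrative helper, not a diffusers function.

```python
import numpy as np


def psnr(reference: np.ndarray, test: np.ndarray, data_range: float = 2.0) -> float:
    """Peak signal-to-noise ratio in dB between two arrays of equal shape.

    data_range=2.0 assumes values in [-1, 1]; identical inputs give +inf.
    """
    ref = reference.astype(np.float64)
    tst = test.astype(np.float64)
    mse = np.mean((ref - tst) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

Because bfloat16 tensors cannot be converted to NumPy directly, a call would look like `psnr(gen_raw.float().cpu().numpy(), tile_gen.float().cpu().numpy())`. A higher value means the tiled output is closer to the non-tiled one.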

Fixes

  • Shape error when using enable_tiling(): the error reports an incorrect number of input channels.
  • Incorrect frame handling in tiled_encode: it does not treat the first frame separately, so it is inconsistent with the default (non-tiled) implementation, and this also causes shape errors.
  • Adjusted where the quant_conv and post_quant_conv operators are applied, to minimize the numerical error as much as possible.
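To illustrate the first-frame issue: as we understand the non-tiled Wan VAE encoder, it is causal in time and compresses time by 4, encoding frame 0 on its own and the remaining frames in groups of 4. Under that assumption, 81 input frames map to 1 + 80 / 4 = 21 latent frames, and a tiled encoder would need to walk each spatial tile through the same frame ranges to stay consistent. The sketch below only computes those ranges; `wan_frame_chunks` is a hypothetical helper, not a diffusers API.

```python
def wan_frame_chunks(num_frames: int, temporal_stride: int = 4) -> list[tuple[int, int]]:
    """Return [start, end) frame ranges: frame 0 alone, then groups of `temporal_stride`.

    Assumes the causal layout described above; trailing frames that do not
    fill a whole group are dropped by the floor division.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")
    num_chunks = 1 + (num_frames - 1) // temporal_stride
    chunks = [(0, 1)]  # the first frame is encoded on its own
    for k in range(1, num_chunks):
        start = 1 + temporal_stride * (k - 1)
        chunks.append((start, start + temporal_stride))
    return chunks
```

With the 81-frame input from the reproduction script, this yields 21 ranges, starting with (0, 1) and (1, 5) and ending with (77, 81), one per latent frame.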

Before submitting

  • [N] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [Y] Did you read the contributor guideline?
  • [Y] Did you read our philosophy doc (important for complex PRs)?
  • [N] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [Y] Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • [Y] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@oneflyingfish
Author

The generated video is only about 3 seconds long because the test used only the first 81 frames of the original video.

@oneflyingfish
Author

@sayakpaul @yiyixuxu @DN6 Sorry, I forgot to tag you in the PR description. Could you please help review?

@oneflyingfish changed the title from "fix wan vae tiling bug" to "fix autoencoder_kl_wan.AutoencoderKLWan.py tiling bugs" on Sep 17, 2025
