WanVACEPipeline doesn't work with bfloat16/float16 VAE #12141

@SlimRG

Description

Describe the bug

When the VAE is loaded with bfloat16 or float16 as its compute dtype, the pipeline fails at the encode step with an error about mismatched input and model dtypes.
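The class of error can be reproduced in isolation with plain PyTorch (a toy half-precision module standing in for the VAE encoder; this is not the diffusers code path, only an illustration of the failure mode):

```python
import torch

# Toy stand-in for the fp16 VAE encoder: a half-precision conv fed a
# float32 tensor, like the preprocessed video frames on the encode step.
encoder = torch.nn.Conv2d(3, 8, kernel_size=3).to(torch.float16)
frames = torch.randn(1, 3, 16, 16)  # torch.randn defaults to float32

try:
    encoder(frames)
except RuntimeError as e:
    # e.g. "Input type (torch.FloatTensor) and weight type (torch.HalfTensor)
    # should be the same"
    print(type(e).__name__)
```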

Reproduction

import logging

import torch
from diffusers import (
    AutoencoderKLWan,
    GGUFQuantizationConfig,
    UniPCMultistepScheduler,
    WanVACEPipeline,
    WanVACETransformer3DModel,
)

log = logging.getLogger(__name__)

TARGET_HEIGHT, TARGET_WIDTH = 480, 832  # placeholder values
flow_shift = 3.0  # placeholder value


def build_pipeline():  # placeholder wrapper for the original function body
    # --- VAE ---
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-VACE-14B-diffusers",
        subfolder="vae",
        torch_dtype=torch.float16,
    )
    vae.enable_tiling(TARGET_HEIGHT, int(TARGET_WIDTH // 1.5))

    # --- Text Encoder ---
    # [TODO - 0001] No support as of 10.08.2025 - https://github.com/huggingface/transformers/issues/40067
    # text_encoder = UMT5EncoderModel.from_pretrained(
    #     "city96/umt5-xxl-encoder-gguf",
    #     gguf_file="umt5-xxl-encoder-Q8_0.gguf",
    #     torch_dtype=torch.float16,
    # )

    # --- Transformer ---
    transformer = WanVACETransformer3DModel.from_single_file(
        "https://huggingface.co/QuantStack/Wan2.1_T2V_14B_FusionX_VACE-GGUF/blob/main/Wan2.1_T2V_14B_FusionX_VACE-Q6_K.gguf",
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.float16),
        torch_dtype=torch.float16,
    )

    # --- Pipeline assembly ---
    pipe = WanVACEPipeline.from_pretrained(
        "Wan-AI/Wan2.1-VACE-14B-diffusers",
        vae=vae,
        # text_encoder=text_encoder,  # [TODO - 0001]
        transformer=transformer,
        torch_dtype=torch.float16,
    )

    # --- Scheduler ---
    pipe.scheduler = UniPCMultistepScheduler.from_config(
        pipe.scheduler.config, flow_shift=flow_shift
    )

    # SageAttention backend (stable for masks)
    pipe.transformer.set_attention_backend("sage")
    log.info("Attention backend set to 'sage'")

    # VAE memory save
    pipe.enable_model_cpu_offload()

    log.info("WAN pipeline is ready")
    return pipe
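A possible workaround, under the assumption that the error comes from float32 frames meeting reduced-precision VAE weights: where you control the tensors yourself, cast the input to the module's weight dtype before encoding. Sketched with a toy stand-in for the VAE encoder (bfloat16 here, matching one of the affected dtypes; not the diffusers code path):

```python
import torch

# Toy stand-in for the VAE encoder: casting the input to the module's
# weight dtype avoids the input/weight dtype mismatch.
encoder = torch.nn.Conv2d(3, 8, kernel_size=3).to(torch.bfloat16)
frames = torch.randn(1, 3, 16, 16)  # float32, like preprocessed video frames

out = encoder(frames.to(next(encoder.parameters()).dtype))
print(out.dtype)  # torch.bfloat16
```

Alternatively, loading the VAE with `torch_dtype=torch.float32` (and keeping only the transformer in float16) may sidestep the encode-step mismatch, at the cost of extra VAE memory.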

Logs

System Info

  • 🤗 Diffusers version: 0.35.0.dev0
  • Platform: Windows-11-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.10
  • PyTorch version (GPU?): 2.7.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.3
  • Transformers version: 4.54.1
  • Accelerate version: 1.1.0
  • PEFT version: 0.15.2
  • Bitsandbytes version: 0.46.0
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.31.post1
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?: No
  • Using: accelerate launch

Who can help?

@DN6 @a-r-r-o-w
@sayakpaul @yiyixuxu

Metadata

Labels: bug (Something isn't working)