WanVACEPipeline - doesn't work with apply_group_offloading #12096

@SlimRG

Description

Describe the bug

When apply_group_offloading is enabled, the pipeline fails with the following error:

Traceback (most recent call last):
  File "D:\Experiments\Video_Outpaint\2__Outpaint.py", line 165, in <module>
    out = pipe(
          ^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\pipelines\wan\pipeline_wan_vace.py", line 873, in __call__
    conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\pipelines\wan\pipeline_wan_vace.py", line 532, in prepare_video_latents
    inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
                                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\utils\accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_wan.py", line 1191, in encode
    h = self._encode(x)
        ^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_wan.py", line 1158, in _encode
    out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_wan.py", line 593, in forward
    x = self.conv_in(x, feat_cache[idx])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\diffusers\models\autoencoders\autoencoder_kl_wan.py", line 176, in forward
    return super().forward(x)
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\conv.py", line 725, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\conv.py", line 720, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Scripts\accelerate.exe\__main__.py", line 7, in <module>
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\accelerate\commands\accelerate_cli.py", line 48, in main
    args.func(args)
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\accelerate\commands\launch.py", line 1168, in launch_command
    simple_launcher(args)
  File "C:\Users\BBCCA\AppData\Local\Programs\Python\Python312\Lib\site-packages\accelerate\commands\launch.py", line 763, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['C:\\Users\\BBCCA\\AppData\\Local\\Programs\\Python\\Python312\\python.exe', '.\\2__Outpaint.py']' returned non-zero exit status 1.
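
The RuntimeError suggests that, at the moment prepare_video_latents calls self.vae.encode(), the VAE's Conv3d weights are still on the CPU while the conditioning video tensor has already been moved to CUDA, i.e. the group-offloading hooks apparently never onload the VAE encoder for this call. A hypothetical one-line check to confirm where the weights live when the failure happens:

# Hypothetical sanity check (assumption: run just before the failing pipe(...) call).
# With group offloading applied, the VAE parameters are expected to sit on the
# offload device until a hook onloads them.
print(next(pipe.vae.parameters()).device)  # expected to print: cpu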

Reproduction

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers",
    subfolder="vae",
    torch_dtype=torch.float32
)

transformer = WanVACETransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q6_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16
)

pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers",
    vae=vae,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)

list(map(lambda module: apply_group_offloading(
    module,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    use_stream=False,
    num_blocks_per_group=2
), [pipe.text_encoder, pipe.transformer, pipe.vae]))

...

out = pipe(
            video=batch_canv,
            mask=batch_mask,
            prompt=prompts[idx],
            height=target_height,
            width=target_width,
            num_frames=batch_size,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale
        ).frames
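
A possible workaround (a sketch under the assumption that the failure is caused by the VAE never being onloaded, not a confirmed fix) is to group-offload only the text encoder and the transformer, and keep the fp32 VAE resident on the GPU so its convolution weights are on the same device as the conditioning latents:

# Sketch of a workaround: group-offload only the text encoder and transformer.
for module in [pipe.text_encoder, pipe.transformer]:
    apply_group_offloading(
        module,
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="block_level",
        use_stream=False,
        num_blocks_per_group=2,
    )
# Keep the VAE on the GPU instead of offloading it (assumption: the fp32 Wan VAE
# fits in VRAM alongside the offloaded transformer blocks).
pipe.vae.to(onload_device)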

Logs

Full code:

import cv2
import torch
import gc
import numpy as np
from PIL import Image, ImageDraw
from tqdm import tqdm
from diffusers import AutoencoderKL, AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.quantizers.quantization_config import GGUFQuantizationConfig
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from accelerate import Accelerator
from transformers import BlipProcessor, BlipForConditionalGeneration
from diffusers.hooks import apply_group_offloading

# === Parameters ===
input_path = "D:/Experiments/Video_Outpaint/2__scenes_all/video-01-005.mp4"
output_path = "D:/Experiments/Video_Outpaint/3__outpaint/video-01-005.mp4"
target_width, target_height = 1280, 720
num_inference_steps = 40
guidance_scale = 5.0
flow_shift = 5.0  # for smoother motion

device = "cuda" if torch.cuda.is_available() else "cpu"

# === Read frames from the input video ===
cap = cv2.VideoCapture(input_path)
frames = []
fps = cap.get(cv2.CAP_PROP_FPS) or 30
while True:
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
cap.release()

# === Generate VAE-noise canvases and masks ===
single_vae = AutoencoderKL.from_pretrained(
    "hackkhai/Flux.1-Dev-nf4-pkg",
    subfolder="vae",
    torch_dtype=torch.bfloat16
)
vae_scale = 8
latent_channels = single_vae.config.latent_channels
canvases, masks = [], []
for img in tqdm(frames, desc="Generating VAE noise canvases"):
    w, h = img.size
    dx = (target_width - w) // 2
    dy = (target_height - h) // 2
    latents = torch.randn(
        1, latent_channels, target_height // vae_scale, target_width // vae_scale,
        device=device, dtype=torch.bfloat16
    ) * single_vae.config.scaling_factor
    latents = latents.to(dtype=single_vae.dtype)
    with torch.no_grad():
        noise = single_vae.decode(latents).sample[0]
    noise_img = (
        (noise / 2 + 0.5).clamp(0, 1)
        .permute(1, 2, 0)
        .detach()
        .to(torch.float32)
        .cpu()
        .numpy()
    )
    canvas = Image.fromarray((noise_img * 255).astype(np.uint8))
    canvas.paste(img, (dx, dy))
    mask = Image.new("L", (target_width, target_height), 255)
    ImageDraw.Draw(mask).rectangle([dx, dy, dx + w, dy + h], fill=0)
    canvases.append(canvas)
    masks.append(mask)

del single_vae
gc.collect()
torch.cuda.empty_cache()

# === 4. Generate prompts with BLIP ===
proc_blip = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large", use_fast=True
)
model_blip = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).eval().to(device)
prompts = []
exists = set()
batch_size = 5
# Step through the frames by batch_size-1 (every 4th frame)
for idx in tqdm(range(0, len(frames), batch_size-1), desc="Generating prompts"):
    img = frames[idx]

    # Get a caption from BLIP
    inputs = proc_blip(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model_blip.generate(**inputs)
    text = proc_blip.decode(output_ids[0], skip_special_tokens=True)
    text = f" {text}"

    # Skip if this caption has already been seen
    if text in exists:
        continue

    exists.add(text)
    prompts.append(text)
print(prompts)

del proc_blip, model_blip
gc.collect()
torch.cuda.empty_cache()

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers",
    subfolder="vae",
    torch_dtype=torch.float32
)

transformer = WanVACETransformer3DModel.from_single_file(
    "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q6_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16
)

pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-14B-diffusers",
    vae=vae,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)

list(map(lambda module: apply_group_offloading(
    module,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    use_stream=False,
    num_blocks_per_group=2
), [pipe.text_encoder, pipe.transformer, pipe.vae]))

# === 6. Video outpainting in overlapping batches with a progress bar ===
result_frames = []
idx = 0
prev_frame = canvases[0]
prev_mask  = masks[0]

with torch.no_grad():
    for i in tqdm(range(0, len(canvases), batch_size-1), desc="Outpainting batches"):
        # take the next chunk: previous frame + up to 4 new frames
        batch_canv = [prev_frame] + canvases[i : i + batch_size - 1]
        batch_mask = [prev_mask] + masks[i : i + batch_size - 1]
        batch_len = len(batch_canv)

        # pad up to batch_size
        if batch_len < batch_size:
            last_c, last_m = batch_canv[-1], batch_mask[-1]
            for _ in range(batch_size - batch_len):
                batch_canv.append(last_c)
                batch_mask.append(last_m)

        # inference
        out = pipe(
            video=batch_canv,
            mask=batch_mask,
            prompt=prompts[idx],
            height=target_height,
            width=target_width,
            num_frames=batch_size,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale
        ).frames

        # drop the padding frames
        out = out[1:batch_len]
        result_frames.extend(out)

        idx += 1

# === 7. Save the video ===
export_to_video(result_frames, output_path, fps=int(fps))
print(f"Готово: {output_path}")

System Info

  • 🤗 Diffusers version: 0.35.0.dev0
  • Platform: Windows-11-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.10
  • PyTorch version (GPU?): 2.7.1+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.34.3
  • Transformers version: 4.54.1
  • Accelerate version: 1.1.0
  • PEFT version: 0.15.2
  • Bitsandbytes version: 0.46.0
  • Safetensors version: 0.5.3
  • xFormers version: 0.0.31.post1
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response
