Skip to content

Commit 8304438

Browse files
charliewwdev and claude committed
V4: StoryEngine + audio-first workflow + lip sync + character portraits
Add the complete V4 pipeline for multi-shot story-driven video generation: - Core: StoryEngine, CharacterManager, ShotScheduler with audio-first duration computation (narration WAV → frame count, 4N+1 aligned) - Backends: Wan 2.2 TI2V-5B (dual-transformer, MPS fallback) + Animate - PostProcess: RIFE interpolation, Real-ESRGAN upscale, F5-TTS audio with per-shot voice profiles, lip sync (MuseTalk/JoyVASA), and VideoCompositor with timed audio alignment via ffmpeg adelay - CLI: story.py (storyboard pipeline), generate_portraits.py (I2V refs) - Storyboards: xianxia demo, fanren trailer V1 + V2 (12 shots with narration, character consistency via I2V, selective lip sync) - Tests: 157 tests (77 V4 + 80 V3), all passing Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0c807d0 commit 8304438

22 files changed

+4863
-8
lines changed

animatediff/backends/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
44
Available backends:
55
- wan: Wan 2.1 (Alibaba, 1.3B-14B, best quality-to-VRAM ratio)
6+
- wan22: Wan 2.2 (MoE dual-path, LoRA support, TI2V-5B for MPS)
7+
- wan22_animate: Wan 2.2 Animate (character animation + replacement)
68
- hunyuan: HunyuanVideo (Tencent, 8.3B, high quality)
79
- cogvideo: CogVideoX (THU, 2B-5B, lightest)
810
- ltx: LTX-Video (Lightricks, real-time capable)
@@ -13,6 +15,8 @@
1315

1416
BACKEND_REGISTRY: Dict[str, str] = {
1517
"wan": "animatediff.backends.wan.WanBackend",
18+
"wan22": "animatediff.backends.wan22.Wan22Backend",
19+
"wan22_animate": "animatediff.backends.wan22_animate.Wan22AnimateBackend",
1620
"hunyuan": "animatediff.backends.hunyuan.HunyuanBackend",
1721
"cogvideo": "animatediff.backends.cogvideo.CogVideoBackend",
1822
"ltx": "animatediff.backends.ltx.LTXBackend",

animatediff/backends/wan22.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""
2+
Wan 2.2 Backend — wraps diffusers WanPipeline for Wan 2.2 models.
3+
4+
Model variants (MoE dual-path architecture):
5+
- Wan-AI/Wan2.2-T2V-A14B-Diffusers (27B total / 14B active, CUDA only — FP8 MoE)
6+
- Wan-AI/Wan2.2-I2V-A14B-Diffusers (27B total / 14B active, CUDA only — FP8 MoE)
7+
- Wan-AI/Wan2.2-TI2V-5B-Diffusers (5B dense, works on MPS — unified T2V+I2V)
8+
9+
Key differences from Wan 2.1:
10+
- Two-stage denoising: high-noise transformer + low-noise transformer (MoE)
11+
- guidance_scale_2 parameter for the second transformer
12+
- Dual-transformer LoRA loading (load_into_transformer_2=True)
13+
- TI2V-5B: dense model, accepts optional image input, 24 fps, 121 frames
14+
15+
NOTE: A14B models use FP8 (Float8_e4m3fn) internally in the MoE experts.
16+
MPS does NOT support FP8 — only TI2V-5B works on Apple Silicon.
17+
"""
18+
19+
import logging
20+
from typing import Optional, List
21+
22+
import torch
23+
from PIL import Image
24+
25+
from animatediff.core.base_pipeline import BasePipeline, VideoOutput
26+
from animatediff.core.quantization import get_quantization_config
27+
28+
logger = logging.getLogger(__name__)
29+
30+
# HuggingFace repo IDs for Wan 2.2 text-to-video checkpoints, keyed by variant.
WAN22_T2V_MODELS = {
    "A14B": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
    "5B": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
}

# HuggingFace repo IDs for image-to-video; the 5B entry is the same unified
# TI2V checkpoint as T2V (there is no separate 5B I2V model).
WAN22_I2V_MODELS = {
    "A14B": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    "5B": "Wan-AI/Wan2.2-TI2V-5B-Diffusers",  # TI2V-5B handles both T2V and I2V
}

# TI2V-5B defaults differ from A14B
# Per-variant generation defaults used by Wan22Backend.generate() when the
# caller passes 0 / None for a parameter. guidance_scale_2 is the CFG scale
# for the second (low-noise) transformer and only applies to the MoE A14B
# variant, hence None for 5B.
MODEL_DEFAULTS = {
    "A14B": dict(width=1280, height=720, num_frames=81, fps=16, guidance_scale=4.0, guidance_scale_2=3.0, steps=40),
    "5B": dict(width=1280, height=704, num_frames=121, fps=24, guidance_scale=5.0, guidance_scale_2=None, steps=50),
}
45+
46+
47+
class Wan22Backend(BasePipeline):
    """Backend wrapping diffusers' Wan 2.2 pipelines.

    Two model variants are supported (see MODEL_DEFAULTS):

    - "A14B": MoE dual-transformer models (separate T2V / I2V checkpoints,
      CUDA only — their experts use FP8, which MPS does not support).
    - "5B": dense TI2V-5B model; unified T2V+I2V, runs on MPS.
    """

    # Key under which this backend is registered in BACKEND_REGISTRY.
    backend_name = "wan22"

    def __init__(self, pipe, model_variant: str = "5B", lora_names: Optional[List[str]] = None):
        """Wrap an already-constructed diffusers pipeline.

        Args:
            pipe: Loaded ``WanPipeline`` or ``WanImageToVideoPipeline``.
            model_variant: "A14B" or "5B"; selects generation defaults.
            lora_names: Adapter names already registered on ``pipe``, if any.
        """
        self.pipe = pipe
        self.model_variant = model_variant
        self.lora_names = lora_names or []
        # Unknown variants silently fall back to the 5B defaults.
        self._defaults = MODEL_DEFAULTS.get(model_variant, MODEL_DEFAULTS["5B"])

    @classmethod
    def load(
        cls,
        model_path: Optional[str] = None,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        quantization: str = "none",
        offload_strategy: str = "none",
        enable_vae_slicing: bool = True,
        enable_vae_tiling: bool = False,
        model_variant: str = "5B",
        mode: str = "t2v",
        lora_paths: Optional[List[str]] = None,
        lora_scales: Optional[List[float]] = None,
        **kwargs,
    ) -> "Wan22Backend":
        """Download/load a Wan 2.2 pipeline and return a ready backend.

        Args:
            model_path: Explicit checkpoint path or HF repo ID. When None,
                resolved from ``mode`` + ``model_variant`` via the module
                registries above.
            torch_dtype: Transformer dtype; forced to float32 on MPS.
            device: Target device ("cuda", "mps", "cpu").
            quantization: Name passed to ``get_quantization_config``.
            offload_strategy: "none" keeps the pipe on ``device``; anything
                else is handed to ``_apply_offloading`` (defined on the base
                class, not visible here).
            enable_vae_slicing / enable_vae_tiling: VAE memory options,
                forwarded to ``_apply_vae_opts``.
            model_variant: "A14B" or "5B".
            mode: "t2v" or "i2v"; only affects which A14B checkpoint and
                pipeline class are chosen (5B is unified).
            lora_paths / lora_scales: Optional LoRA weights to load;
                missing scales default to 1.0 each.

        Returns:
            A configured ``Wan22Backend`` instance.
        """
        from diffusers import WanPipeline, WanImageToVideoPipeline, AutoencoderKLWan

        # Resolve model path
        if model_path is None:
            if mode == "i2v" and model_variant != "5B":
                model_path = WAN22_I2V_MODELS.get(model_variant, WAN22_I2V_MODELS["A14B"])
            else:
                model_path = WAN22_T2V_MODELS.get(model_variant, WAN22_T2V_MODELS["5B"])

        # MPS safety: A14B uses FP8 internally, which MPS doesn't support.
        # Note this overrides any explicit model_path the caller supplied.
        if device == "mps" and model_variant == "A14B":
            logger.warning("Wan 2.2 A14B uses FP8 MoE experts — not supported on MPS. Falling back to TI2V-5B.")
            model_variant = "5B"
            model_path = WAN22_T2V_MODELS["5B"]

        # MPS requires float32 for Wan models
        if device == "mps":
            torch_dtype = torch.float32
            logger.info("MPS detected: using float32 (float16/bfloat16 not fully supported for Wan on MPS)")

        logger.info(f"Loading Wan 2.2 {model_variant} from {model_path} (dtype={torch_dtype}, quant={quantization})")

        # VAE must always be float32 for Wan
        vae = AutoencoderKLWan.from_pretrained(model_path, subfolder="vae", torch_dtype=torch.float32)

        # Quantization config — for A14B, quantize both transformers
        components = ["transformer", "transformer_2"] if model_variant == "A14B" else ["transformer"]
        quant_config = get_quantization_config(quantization, components=components)

        load_kwargs = dict(torch_dtype=torch_dtype, vae=vae)
        if quant_config is not None:
            load_kwargs["quantization_config"] = quant_config

        # Choose pipeline class
        if mode == "i2v" and model_variant == "A14B":
            PipelineClass = WanImageToVideoPipeline
        else:
            # TI2V-5B uses WanPipeline for both T2V and I2V
            PipelineClass = WanPipeline

        pipe = PipelineClass.from_pretrained(model_path, **load_kwargs)

        # Fix: transformers 5.x UMT5 embed_tokens zero-weight bug (same as Wan 2.1).
        # If the encoder's embed_tokens came back all-zero while the shared
        # embedding table is populated, rebind the shared weight in place.
        te = pipe.text_encoder
        if (hasattr(te, "shared") and hasattr(te, "encoder")
                and hasattr(te.encoder, "embed_tokens")
                and te.encoder.embed_tokens.weight.abs().sum().item() == 0
                and te.shared.weight.abs().sum().item() > 0):
            logger.warning("Fixing UMT5 embed_tokens: binding shared.weight -> encoder.embed_tokens.weight")
            te.encoder.embed_tokens.weight = te.shared.weight

        instance = cls(pipe, model_variant=model_variant)

        # Load LoRAs if provided
        if lora_paths:
            instance._load_loras(lora_paths, lora_scales or [1.0] * len(lora_paths))

        # Apply offloading; mutually exclusive with placing the whole pipe
        # on the target device.
        if offload_strategy != "none":
            instance._apply_offloading(pipe, offload_strategy, device=device)
        else:
            pipe.to(device)

        instance._apply_vae_opts(pipe, slicing=enable_vae_slicing, tiling=enable_vae_tiling)

        return instance

    def _load_loras(self, lora_paths: List[str], lora_scales: List[float]):
        """Load LoRA weights. For A14B, supports dual-transformer LoRA loading.

        Paths containing "_LOW"/"_low"/"transformer_2" are routed to the
        low-noise second transformer via ``load_into_transformer_2`` (a
        filename convention, not a guarantee — mislabeled files load into
        the wrong transformer silently).
        """
        for i, (path, scale) in enumerate(zip(lora_paths, lora_scales)):
            adapter_name = f"lora_{i}"

            # Detect if this is a dual-transformer LoRA (by filename convention)
            is_low_noise = "_LOW" in path or "_low" in path or "transformer_2" in path

            load_kwargs = dict(adapter_name=adapter_name)
            if is_low_noise and self.model_variant == "A14B":
                load_kwargs["load_into_transformer_2"] = True
                logger.info(f"Loading LoRA into transformer_2 (low-noise): {path} (scale={scale})")
            else:
                logger.info(f"Loading LoRA into transformer (high-noise): {path} (scale={scale})")

            # Handle both repo IDs and local paths
            if "/" in path and not path.startswith("/") and not path.startswith("."):
                # Looks like a HuggingFace repo ID — split off weight_name
                # when the last path segment has a file extension.
                parts = path.rsplit("/", 1)
                if len(parts) == 2 and "." in parts[1]:
                    self.pipe.load_lora_weights(parts[0], weight_name=parts[1], **load_kwargs)
                else:
                    self.pipe.load_lora_weights(path, **load_kwargs)
            else:
                self.pipe.load_lora_weights(path, **load_kwargs)

            self.lora_names.append(adapter_name)

        # NOTE(review): if _load_loras is called more than once, this slice
        # pairs the new scales with previously appended adapter names —
        # presumably it is only called once, from load(); verify.
        if self.lora_names:
            scales = lora_scales[:len(self.lora_names)]
            self.pipe.set_adapters(self.lora_names, adapter_weights=scales)
            logger.info(f"Activated LoRAs: {self.lora_names} with scales {scales}")

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 0,
        height: int = 0,
        num_frames: int = 0,
        num_inference_steps: int = 0,
        guidance_scale: float = 0,
        seed: int = -1,
        image: Optional[Image.Image] = None,
        guidance_scale_2: Optional[float] = None,
        **kwargs,
    ) -> VideoOutput:
        """Run the pipeline and return frames wrapped in a VideoOutput.

        Zero/empty values for width/height/num_frames/steps/guidance_scale
        mean "use the variant default" from MODEL_DEFAULTS. Passing ``image``
        switches TI2V-5B into image-conditioned generation.

        NOTE(review): ``guidance_scale or d[...]`` means an explicit 0.0
        (CFG disabled) is replaced by the default — confirm that is intended.
        """
        d = self._defaults
        width = width or d["width"]
        height = height or d["height"]
        num_frames = num_frames or d["num_frames"]
        num_inference_steps = num_inference_steps or d["steps"]
        guidance_scale = guidance_scale or d["guidance_scale"]

        # Seed the RNG on the pipeline's device (CPU fallback when offloaded).
        gen_device = "cpu" if self.pipe.device.type == "cpu" else self.pipe.device
        generator = self._make_generator(seed, gen_device)

        pipe_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            width=width,
            height=height,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        )

        # A14B MoE: separate guidance for the second transformer
        gs2 = guidance_scale_2 if guidance_scale_2 is not None else d.get("guidance_scale_2")
        if gs2 is not None and self.model_variant == "A14B":
            pipe_kwargs["guidance_scale_2"] = gs2

        # Image-to-video (TI2V-5B accepts image as optional input)
        # NOTE(review): assumes plain WanPipeline accepts an ``image`` kwarg
        # for the 5B variant — confirm against the installed diffusers version.
        if image is not None:
            pipe_kwargs["image"] = image

        output = self.pipe(**pipe_kwargs)
        # diffusers returns frames batched; take the first (only) video.
        frames = output.frames[0]

        return VideoOutput(
            frames=frames,
            fps=d["fps"],
            seed=seed,
            backend=self.backend_name,
            metadata={
                "model_variant": self.model_variant,
                "loras": self.lora_names,
            },
        )

0 commit comments

Comments
 (0)