Commit 359151d

[Feature] Add wan2.2 5b i2v (#760)
Co-authored-by: SolitaryThinker <[email protected]>
1 parent ce67cd3 commit 359151d

File tree: 11 files changed (+282 −20 lines)
Lines changed: 41 additions & 0 deletions (new file)
@@ -0,0 +1,41 @@
+from fastvideo import VideoGenerator
+
+OUTPUT_PATH = "video_samples_wan2_2_5B_ti2v"
+def main():
+    # FastVideo will automatically use the optimal default arguments for the
+    # model.
+    # If a local path is provided, FastVideo will make a best effort
+    # attempt to identify the optimal arguments.
+    model_name = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
+    generator = VideoGenerator.from_pretrained(
+        model_name,
+        # FastVideo will automatically handle distributed setup
+        num_gpus=1,
+        use_fsdp_inference=True,
+        dit_cpu_offload=True,
+        vae_cpu_offload=False,
+        text_encoder_cpu_offload=True,
+        pin_cpu_memory=True,  # set to false if low CPU RAM or hit obscure "CUDA error: Invalid argument"
+        # image_encoder_cpu_offload=False,
+    )
+
+    # I2V is triggered just by passing in an image_path argument
+    prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+    image_path = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
+    video = generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True, image_path=image_path)
+
+    # Generate another video with a different prompt, without reloading the
+    # model!
+
+    # T2V mode
+    prompt2 = (
+        "A majestic lion strides across the golden savanna, its powerful frame "
+        "glistening under the warm afternoon sun. The tall grass ripples gently in "
+        "the breeze, enhancing the lion's commanding presence. The tone is vibrant, "
+        "embodying the raw energy of the wild. Low angle, steady tracking shot, "
+        "cinematic.")
+    video2 = generator.generate_video(prompt2, output_path=OUTPUT_PATH, save_video=True)
+
+
+if __name__ == "__main__":
+    main()

fastvideo/configs/pipelines/base.py

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,9 @@ class PipelineConfig:
     # DMD parameters
     dmd_denoising_steps: list[int] | None = field(default=None)
 
+    # Wan2.2 TI2V parameters
+    ti2v_task: bool = False
+
     # Compilation
     # enable_torch_compile: bool = False
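The flag defaults to False, so existing pipelines are untouched; the TI2V path checks it explicitly in fastvideo/pipelines/stages/denoising.py below. A minimal sketch of toggling it, assuming PipelineConfig can be instantiated with its defaults (its other fields are not shown in this diff):

from fastvideo.configs.pipelines.base import PipelineConfig

# Sketch only: ti2v_task stays False unless a Wan2.2 TI2V config enables it.
cfg = PipelineConfig(ti2v_task=True)
print(cfg.ti2v_task)  # True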

fastvideo/configs/pipelines/registry.py

Lines changed: 11 additions & 8 deletions
@@ -7,11 +7,14 @@
 from fastvideo.configs.pipelines.base import PipelineConfig
 from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
 from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig
-from fastvideo.configs.pipelines.wan import (FastWan2_1_T2V_480P_Config,
-                                             FastWan2_2_TI2V_5B_Config,
-                                             SelfForcingWanT2V480PConfig,
-                                             WanI2V480PConfig, WanI2V720PConfig,
-                                             WanT2V480PConfig, WanT2V720PConfig)
+
+# isort: off
+from fastvideo.configs.pipelines.wan import (
+    FastWan2_1_T2V_480P_Config, FastWan2_2_TI2V_5B_Config,
+    SelfForcingWanT2V480PConfig, Wan2_2_I2V_A14B_Config, Wan2_2_T2V_A14B_Config,
+    Wan2_2_TI2V_5B_Config, WanI2V480PConfig, WanI2V720PConfig, WanT2V480PConfig,
+    WanT2V720PConfig)
+# isort: on
 from fastvideo.logger import init_logger
 from fastvideo.utils import (maybe_download_model_index,
                              verify_model_config_and_directory)
@@ -32,10 +35,10 @@
     "FastVideo/FastWan2.2-TI2V-5B-Diffusers": FastWan2_2_TI2V_5B_Config,
     "FastVideo/stepvideo-t2v-diffusers": StepVideoT2VConfig,
     "FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers": WanT2V720PConfig,
-    "Wan-AI/Wan2.2-TI2V-5B-Diffusers": WanT2V720PConfig,
-    "Wan-AI/Wan2.2-T2V-A14B-Diffusers": WanT2V480PConfig,
-    "Wan-AI/Wan2.2-I2V-A14B-Diffusers": WanI2V480PConfig,
     "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig,
+    "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config,
+    "Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config,
+    "Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config,
     # Add other specific weight variants
 }

fastvideo/configs/sample/registry.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,7 @@
 # isort: off
 from fastvideo.configs.sample.wan import (
     FastWanT2V480PConfig,
+    Wan2_1_Fun_1_3B_InP_SamplingParam,
     Wan2_2_I2V_A14B_SamplingParam,
     Wan2_2_T2V_A14B_SamplingParam,
     Wan2_2_TI2V_5B_SamplingParam,
@@ -36,6 +37,8 @@
     "Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V_14B_SamplingParam,
     "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V_14B_480P_SamplingParam,
     "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V_14B_720P_SamplingParam,
+    "weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers":
+    Wan2_1_Fun_1_3B_InP_SamplingParam,
 
     # Wan2.2
     "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam,

fastvideo/configs/sample/wan.py

Lines changed: 15 additions & 0 deletions
@@ -107,6 +107,21 @@ class FastWanT2V480PConfig(WanT2V_1_3B_SamplingParam):
     fps: int = 16
 
 
+# =============================================
+# ============= Wan2.1 Fun Models =============
+# =============================================
+@dataclass
+class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam):
+    """Sampling parameters for Wan2.1 Fun 1.3B InP model."""
+    height: int = 480
+    width: int = 832
+    num_frames: int = 81
+    fps: int = 16
+    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+    guidance_scale: float = 6.0
+    num_inference_steps: int = 50
+
+
 # =============================================
 # ============= Wan2.2 TI2V Models =============
 # =============================================
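For reference, a small sketch of reading the new defaults, assuming the dataclass (and its SamplingParam base) can be constructed without arguments:

from fastvideo.configs.sample.wan import Wan2_1_Fun_1_3B_InP_SamplingParam

param = Wan2_1_Fun_1_3B_InP_SamplingParam()
# Defaults taken from the dataclass above: 480x832, 81 frames, 16 fps,
# guidance 6.0, 50 inference steps.
print(param.height, param.width, param.num_frames, param.guidance_scale)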

fastvideo/layers/visual_embedding.py

Lines changed: 5 additions & 1 deletion
@@ -86,12 +86,16 @@ def __init__(
                          dtype=dtype)
         self.freq_dtype = freq_dtype
 
-    def forward(self, t: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                t: torch.Tensor,
+                timestep_seq_len: int | None = None) -> torch.Tensor:
         t_freq = timestep_embedding(t,
                                     self.frequency_embedding_size,
                                     self.max_period,
                                     dtype=self.freq_dtype).to(
                                         self.mlp.fc_in.weight.dtype)
+        if timestep_seq_len is not None:
+            t_freq = t_freq.unflatten(0, (1, timestep_seq_len))
         # t_freq = t_freq.to(self.mlp.fc_in.weight.dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
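For intuition, a standalone sketch of the new timestep_seq_len branch: when the TI2V path passes a flattened per-token timestep vector, the frequency embeddings are reshaped back to (1, seq_len, dim) before the MLP. Shapes below are illustrative only:

import torch

seq_len, dim = 120, 256
t_freq = torch.randn(seq_len, dim)          # embeddings of a flattened (1 * seq_len) timestep vector
t_freq = t_freq.unflatten(0, (1, seq_len))  # -> (1, seq_len, dim): one embedding per latent token
print(t_freq.shape)                         # torch.Size([1, 120, 256])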

fastvideo/models/dits/wanvideo.py

Lines changed: 44 additions & 8 deletions
@@ -81,8 +81,9 @@ def forward(
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: torch.Tensor | None = None,
+        timestep_seq_len: int | None = None,
     ):
-        temb = self.time_embedder(timestep)
+        temb = self.time_embedder(timestep, timestep_seq_len)
         timestep_proj = self.time_modulation(temb)
 
         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
@@ -319,9 +320,24 @@ def forward(
         bs, seq_length, _ = hidden_states.shape
         orig_dtype = hidden_states.dtype
         # assert orig_dtype != torch.float32
-        e = self.scale_shift_table + temb.float()
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
-            6, dim=1)
+
+        if temb.dim() == 4:
+            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table.unsqueeze(0) + temb.float()
+            ).chunk(6, dim=2)
+            # batch_size, seq_len, 1, inner_dim
+            shift_msa = shift_msa.squeeze(2)
+            scale_msa = scale_msa.squeeze(2)
+            gate_msa = gate_msa.squeeze(2)
+            c_shift_msa = c_shift_msa.squeeze(2)
+            c_scale_msa = c_scale_msa.squeeze(2)
+            c_gate_msa = c_gate_msa.squeeze(2)
+        else:
+            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+            e = self.scale_shift_table + temb.float()
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
+                6, dim=1)
         assert shift_msa.dtype == torch.float32
 
         # 1. Self-attention
@@ -649,9 +665,21 @@ def forward(self,
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
+        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+        if timestep.dim() == 2:
+            ts_seq_len = timestep.shape[1]
+            timestep = timestep.flatten()  # batch_size * seq_len
+        else:
+            ts_seq_len = None
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image)
-        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len)
+        if ts_seq_len is not None:
+            # batch_size, seq_len, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # batch_size, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat(
@@ -688,8 +716,15 @@ def forward(self,
         if enable_teacache:
             self.maybe_cache_states(hidden_states, original_hidden_states)
         # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2,
-                                                                          dim=1)
+        if temb.dim() == 3:
+            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # batch_size, inner_dim
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
         hidden_states = self.norm_out(hidden_states, shift, scale)
         hidden_states = self.proj_out(hidden_states)
 
@@ -793,3 +828,4 @@ def retrieve_cached_states(self,
             return hidden_states + self.previous_residual_even
         else:
             return hidden_states + self.previous_residual_odd
+
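For intuition, a standalone sketch of the new 4-D temb branch: with per-token modulation, the scale_shift_table (assumed shape (1, 6, inner_dim)) broadcasts over a (batch, seq_len, 6, inner_dim) temb, and each chunk is squeezed back to (batch, seq_len, inner_dim). Shapes are illustrative only:

import torch

batch, seq_len, inner_dim = 1, 120, 64
scale_shift_table = torch.randn(1, 6, inner_dim)      # learned per-layer table (assumed shape)
temb = torch.randn(batch, seq_len, 6, inner_dim)      # per-token modulation (wan2.2 ti2v case)

chunks = (scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
shift_msa = chunks[0].squeeze(2)                      # (batch, seq_len, inner_dim)
print(shift_msa.shape)                                # torch.Size([1, 120, 64])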

fastvideo/pipelines/basic/wan/wan_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
                            transformer=self.get_module("transformer"),
                            transformer_2=self.get_module("transformer_2", None),
                            scheduler=self.get_module("scheduler"),
+                           vae=self.get_module("vae"),
                            pipeline=self))
 
         self.add_stage(stage_name="decoding_stage",

fastvideo/pipelines/stages/denoising.py

Lines changed: 62 additions & 3 deletions
@@ -4,6 +4,7 @@
 """
 
 import inspect
+import math
 import weakref
 from collections.abc import Iterable
 from typing import Any
@@ -30,7 +31,7 @@
 from fastvideo.pipelines.stages.validators import StageValidators as V
 from fastvideo.pipelines.stages.validators import VerificationResult
 from fastvideo.platforms import AttentionBackendEnum
-from fastvideo.utils import dict_to_3d_list
+from fastvideo.utils import dict_to_3d_list, masks_like
 
 try:
     from fastvideo.attention.backends.sliding_tile_attn import (
@@ -61,11 +62,13 @@ def __init__(self,
                  transformer,
                  scheduler,
                  pipeline=None,
-                 transformer_2=None) -> None:
+                 transformer_2=None,
+                 vae=None) -> None:
         super().__init__()
         self.transformer = transformer
         self.transformer_2 = transformer_2
         self.scheduler = scheduler
+        self.vae = vae
         self.pipeline = weakref.ref(pipeline) if pipeline else None
         attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
         self.attn_backend = get_attn_backend(
@@ -194,6 +197,44 @@ def forward(
             boundary_timestep = fastvideo_args.boundary_ratio * self.scheduler.num_train_timesteps
         else:
             boundary_timestep = None
+        latent_model_input = latents.to(target_dtype)
+        assert latent_model_input.shape[0] == 1, "only support batch size 1"
+
+        if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
+            # TI2V directly replaces the first frame of the latent with
+            # the image latent instead of appending along the channel dim
+            assert batch.image_latent is None, "TI2V task should not have image latents"
+            assert self.vae is not None, "VAE is not provided for TI2V task"
+            z = self.vae.encode(batch.pil_image).mean.float()
+            if (hasattr(self.vae, "shift_factor")
+                    and self.vae.shift_factor is not None):
+                if isinstance(self.vae.shift_factor, torch.Tensor):
+                    z -= self.vae.shift_factor.to(z.device, z.dtype)
+                else:
+                    z -= self.vae.shift_factor
+
+            if isinstance(self.vae.scaling_factor, torch.Tensor):
+                z = z * self.vae.scaling_factor.to(z.device, z.dtype)
+            else:
+                z = z * self.vae.scaling_factor
+
+            latent_model_input = latent_model_input.squeeze(0)
+            _, mask2 = masks_like([latent_model_input], zero=True)
+
+            latent_model_input = (1. -
+                                  mask2[0]) * z + mask2[0] * latent_model_input
+            # latent_model_input = latent_model_input.unsqueeze(0)
+            latent_model_input = latent_model_input.to(get_local_torch_device())
+            latents = latent_model_input
+            F = batch.num_frames
+            temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
+            spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
+            patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
+            seq_len = ((F - 1) // temporal_scale +
+                       1) * (batch.height // spatial_scale) * (
+                           batch.width // spatial_scale) // (patch_size[1] *
+                                                             patch_size[2])
+            seq_len = int(math.ceil(seq_len / sp_world_size)) * sp_world_size
 
         # Run denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -218,19 +259,32 @@ def forward(
                     self.transformer.to('cpu')
                     current_model = self.transformer_2
                     current_guidance_scale = batch.guidance_scale_2
+                    assert current_model is not None, "current_model is None"
 
                 # Expand latents for I2V
                 latent_model_input = latents.to(target_dtype)
                 if batch.image_latent is not None:
+                    assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                     latent_model_input = torch.cat(
                         [latent_model_input, batch.image_latent],
                         dim=1).to(target_dtype)
+                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
+                    timestep = torch.stack([t]).to(get_local_torch_device())
+                    temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
+                    temp_ts = torch.cat([
+                        temp_ts,
+                        temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
+                    ])
+                    timestep = temp_ts.unsqueeze(0)
+                    t_expand = timestep.repeat(latent_model_input.shape[0], 1)
+                else:
+                    t_expand = t.repeat(latent_model_input.shape[0])
+
                 assert torch.isnan(latent_model_input).sum() == 0
                 latent_model_input = self.scheduler.scale_model_input(
                     latent_model_input, t)
 
                 # Prepare inputs for transformer
-                t_expand = t.repeat(latent_model_input.shape[0])
                 guidance_expand = (
                     torch.tensor(
                         [fastvideo_args.pipeline_config.embedded_cfg_scale] *
@@ -330,6 +384,11 @@ def forward(
                     latents,
                     **extra_step_kwargs,
                     return_dict=False)[0]
+                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
+                    latents = latents.squeeze(0)
+                    latents = (1. - mask2[0]) * z + mask2[0] * latents
+                    # latents = latents.unsqueeze(0)
+
                 # Update progress bar
                 if i == len(timesteps) - 1 or (
                     (i + 1) > num_warmup_steps and
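For intuition, a standalone sketch of the seq_len computation used above to pad the per-token timestep vector: latent frames come from the VAE temporal stride, spatial tokens from the VAE spatial stride and the DiT patch size, and the total is rounded up to a multiple of the sequence-parallel world size. The numbers below are illustrative assumptions, not values read from the model config:

import math

num_frames, height, width = 121, 704, 1280    # illustrative request
temporal_scale, spatial_scale = 4, 16          # assumed VAE compression factors
patch_size = (1, 2, 2)                         # assumed DiT patch size (t, h, w)
sp_world_size = 2                              # sequence-parallel world size

latent_frames = (num_frames - 1) // temporal_scale + 1
seq_len = latent_frames * (height // spatial_scale) * (width // spatial_scale) \
    // (patch_size[1] * patch_size[2])
seq_len = int(math.ceil(seq_len / sp_world_size)) * sp_world_size
print(latent_frames, seq_len)                  # 31 27280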
