Commit ce67cd3

[Feat] Support Self-Forcing's Causal Inference for Wan2.1 T2V 1.3B (#766)
Co-authored-by: SolitaryThinker <[email protected]>
1 parent 7c554e5 commit ce67cd3

File tree

19 files changed (+1412, -30 lines)
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import os
+import time
+from fastvideo import VideoGenerator, SamplingParam
+
+OUTPUT_PATH = "video_samples_causal"
+def main():
+    # FastVideo will automatically use the optimal default arguments for the
+    # model.
+    # If a local path is provided, FastVideo will make a best effort
+    # attempt to identify the optimal arguments.
+    model_name = "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers"
+    generator = VideoGenerator.from_pretrained(
+        model_name,
+        # FastVideo will automatically handle distributed setup
+        num_gpus=1,
+        use_fsdp_inference=True,
+        text_encoder_cpu_offload=False,
+        dit_cpu_offload=False,
+    )
+
+    sampling_param = SamplingParam.from_pretrained(model_name)
+
+    prompt = (
+        "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes "
+        "wide with interest. The playful yet serene atmosphere is complemented by soft "
+        "natural light filtering through the petals. Mid-shot, warm and cheerful tones."
+    )
+    video = generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True, sampling_param=sampling_param)
+
+if __name__ == "__main__":
+    main()
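
The script above relies entirely on the model's registered defaults. If you want to adjust generation settings, the SamplingParam object returned by SamplingParam.from_pretrained can be modified before it is passed to generate_video. A minimal sketch, assuming the object exposes mutable fields such as num_frames and seed (those attribute names are an assumption; check the SamplingParam dataclass for the exact fields):

from fastvideo import VideoGenerator, SamplingParam

model_name = "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers"
generator = VideoGenerator.from_pretrained(model_name, num_gpus=1)

sampling_param = SamplingParam.from_pretrained(model_name)
# Assumed field names; verify against the SamplingParam definition.
sampling_param.num_frames = 81   # total output frames
sampling_param.seed = 42         # fixed seed for reproducible runs

generator.generate_video(
    "A raccoon peering through sunflowers, warm natural light.",
    output_path="video_samples_causal",
    save_video=True,
    sampling_param=sampling_param,
)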

fastvideo/configs/models/dits/wanvideo.py

Lines changed: 6 additions & 0 deletions
@@ -92,6 +92,12 @@ class WanVideoArchConfig(DiTArchConfig):
     pos_embed_seq_len: int | None = None
     exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
 
+    # Causal Wan
+    local_attn_size: int = -1  # Window size for temporal local attention (-1 indicates global attention)
+    sink_size: int = 0  # Size of the attention sink; the first `sink_size` frames are kept unchanged when rolling the KV cache
+    num_frames_per_block: int = 3
+    sliding_window_num_frames: int = 21
+
     def __post_init__(self):
         super().__post_init__()
         self.out_channels = self.out_channels or self.in_channels
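
The four new fields describe the causal attention pattern: local_attn_size sets the temporal local-attention window (-1 means global attention), sink_size pins the first frames as an attention sink that is not evicted when the KV cache rolls, num_frames_per_block = 3 suggests frames are generated three latent frames per block, and sliding_window_num_frames = 21 bounds the attended history. Purely as an illustration of the rolling-cache-with-sink idea (a toy, hypothetical helper, not the implementation added elsewhere in this commit):

from collections import deque

def roll_kv_cache(sink: list, cache: deque, new_frames: list,
                  sink_size: int = 0, window: int = 21) -> None:
    # Toy illustration: pin the first `sink_size` frames and evict the oldest
    # remaining frames so the attended history never exceeds `window` frames.
    for frame_kv in new_frames:
        if len(sink) < sink_size:
            sink.append(frame_kv)      # attention-sink frames are never evicted
        else:
            cache.append(frame_kv)
        while len(sink) + len(cache) > window:
            cache.popleft()            # drop the oldest non-sink frame

With the committed defaults (sink_size = 0, sliding_window_num_frames = 21, num_frames_per_block = 3), new_frames would arrive three latent frames at a time and the attended history would stay within 21 frames.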

fastvideo/configs/pipelines/registry.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig
 from fastvideo.configs.pipelines.wan import (FastWan2_1_T2V_480P_Config,
                                              FastWan2_2_TI2V_5B_Config,
+                                             SelfForcingWanT2V480PConfig,
                                              WanI2V480PConfig, WanI2V720PConfig,
                                              WanT2V480PConfig, WanT2V720PConfig)
 from fastvideo.logger import init_logger
@@ -34,6 +35,7 @@
     "Wan-AI/Wan2.2-TI2V-5B-Diffusers": WanT2V720PConfig,
     "Wan-AI/Wan2.2-T2V-A14B-Diffusers": WanT2V480PConfig,
     "Wan-AI/Wan2.2-I2V-A14B-Diffusers": WanI2V480PConfig,
+    "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig,
     # Add other specific weight variants
 }

fastvideo/configs/pipelines/wan.py

Lines changed: 10 additions & 0 deletions
@@ -138,3 +138,13 @@ class Wan2_2_T2V_A14B_Config(WanT2V480PConfig):
 @dataclass
 class Wan2_2_I2V_A14B_Config(WanT2V480PConfig):
     pass
+
+
+# =============================================
+# ============= Causal Self-Forcing =============
+# =============================================
+@dataclass
+class SelfForcingWanT2V480PConfig(WanT2V480PConfig):
+    is_causal: bool = True
+    dmd_denoising_steps: list[int] | None = field(
+        default_factory=lambda: [1000, 750, 500, 250])
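
The new pipeline config flags the model as causal and pins dmd_denoising_steps to [1000, 750, 500, 250], which reads as a fixed four-step schedule for the DMD-distilled checkpoint rather than a full scheduler sweep. A hedged sketch of how such a fixed timestep list is typically consumed (a generic few-step loop, not FastVideo's actual denoising stage):

import torch

def run_fixed_schedule(denoise_fn, latents: torch.Tensor,
                       steps=(1000, 750, 500, 250)) -> torch.Tensor:
    # Generic few-step loop: one denoiser call per fixed timestep.
    for t in steps:
        timestep = torch.full((latents.shape[0],), t,
                              device=latents.device, dtype=torch.long)
        latents = denoise_fn(latents, timestep)
    return latents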

fastvideo/configs/sample/registry.py

Lines changed: 15 additions & 2 deletions
@@ -17,6 +17,7 @@
     WanI2V_14B_720P_SamplingParam,
     WanT2V_1_3B_SamplingParam,
     WanT2V_14B_SamplingParam,
+    SelfForcingWanT2V480PConfig,
 )
 # isort: on
 from fastvideo.logger import init_logger
@@ -28,17 +29,29 @@
 SAMPLING_PARAM_REGISTRY: dict[str, Any] = {
     "FastVideo/FastHunyuan-diffusers": FastHunyuanSamplingParam,
     "hunyuanvideo-community/HunyuanVideo": HunyuanSamplingParam,
+    "FastVideo/stepvideo-t2v-diffusers": StepVideoT2VSamplingParam,
+
+    # Wan2.1
     "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V_1_3B_SamplingParam,
     "Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V_14B_SamplingParam,
     "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V_14B_480P_SamplingParam,
     "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V_14B_720P_SamplingParam,
-    "FastVideo/stepvideo-t2v-diffusers": StepVideoT2VSamplingParam,
-    "FastVideo/FastWan2.1-T2V-1.3B-Diffusers": FastWanT2V480PConfig,
+
+    # Wan2.2
     "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam,
     "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers":
     Wan2_2_TI2V_5B_SamplingParam,
     "Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_SamplingParam,
     "Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_SamplingParam,
+
+    # FastWan2.1
+    "FastVideo/FastWan2.1-T2V-1.3B-Diffusers": FastWanT2V480PConfig,
+
+    # FastWan2.2
+    "FastVideo/FastWan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam,
+
+    # Causal Self-Forcing Wan2.1
+    "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig,
     # Add other specific weight variants
 }
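
With this registry entry in place, SamplingParam.from_pretrained("wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers") should resolve to SelfForcingWanT2V480PConfig, which (per fastvideo/configs/sample/wan.py below) currently just inherits the Wan2.1 1.3B sampling defaults:

from fastvideo import SamplingParam

# Resolved via SAMPLING_PARAM_REGISTRY; inherits WanT2V_1_3B_SamplingParam defaults.
param = SamplingParam.from_pretrained("wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers")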

fastvideo/configs/sample/wan.py

Lines changed: 8 additions & 0 deletions
@@ -141,3 +141,11 @@ class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
     guidance_scale_2: float = 3.5
     num_inference_steps: int = 40
     fps: int = 16
+
+
+# =============================================
+# ============= Causal Self-Forcing =============
+# =============================================
+@dataclass
+class SelfForcingWanT2V480PConfig(WanT2V_1_3B_SamplingParam):
+    pass

fastvideo/layers/rotary_embedding.py

Lines changed: 9 additions & 0 deletions
@@ -29,6 +29,9 @@
 
 from fastvideo.distributed.parallel_state import get_sp_group
 from fastvideo.layers.custom_op import CustomOp
+from fastvideo.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -267,6 +270,7 @@ def get_nd_rotary_pos_embed(
     sp_rank: int = 0,
     sp_world_size: int = 1,
     dtype: torch.dtype = torch.float32,
+    start_frame: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
@@ -292,6 +296,9 @@
     full_grid = get_meshgrid_nd(
         start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]
 
+    if start_frame > 0:
+        full_grid[0] += start_frame
+
     # Shard the grid if using sequence parallelism (sp_world_size > 1)
     assert shard_dim < len(
         rope_dim_list
@@ -370,6 +377,7 @@ def get_rotary_pos_embed(
     interpolation_factor=1.0,
     shard_dim: int = 0,
     dtype: torch.dtype = torch.float32,
+    start_frame: int = 0,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generate rotary positional embeddings for the given sizes.
@@ -413,6 +421,7 @@
         sp_rank=sp_rank,
         sp_world_size=sp_world_size,
         dtype=dtype,
+        start_frame=start_frame,
     )
     return freqs_cos, freqs_sin
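
The new start_frame argument offsets the temporal axis of the RoPE grid (full_grid[0] += start_frame), presumably so that each causally generated block of frames continues the absolute frame positions of the frames already held in the KV cache instead of restarting at position 0. A minimal 1-D sketch of the idea, independent of get_nd_rotary_pos_embed's actual internals:

import torch

def temporal_rope_freqs(num_frames: int, dim: int, start_frame: int = 0,
                        theta: float = 10000.0):
    # cos/sin tables for frame positions start_frame .. start_frame + num_frames - 1
    positions = torch.arange(start_frame, start_frame + num_frames,
                             dtype=torch.float32)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(positions, inv_freq)   # [num_frames, dim // 2]
    return angles.cos(), angles.sin()

# Block 0 covers frames 0-2; the next block continues at frame 3 with the same code.
cos0, sin0 = temporal_rope_freqs(3, 64, start_frame=0)
cos1, sin1 = temporal_rope_freqs(3, 64, start_frame=3)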
