Skip to content

Commit 8b89aff

Browse files
committed
Add cosmos2 i2v pipeline
1 parent 4f3e875 commit 8b89aff

32 files changed

+2582
-54
lines changed
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
12
from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
23
from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
34
from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
45

5-
__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"]
6+
__all__ = [
7+
"HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig",
8+
"CosmosVideoConfig"
9+
]
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from dataclasses import dataclass, field
3+
4+
from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig
5+
6+
7+
def is_transformer_blocks(n: str, m) -> bool:
8+
return "transformer_blocks" in n and str.isdigit(n.split(".")[-1])
9+
10+
11+
@dataclass
12+
class CosmosArchConfig(DiTArchConfig):
13+
_fsdp_shard_conditions: list = field(
14+
default_factory=lambda: [is_transformer_blocks])
15+
16+
param_names_mapping: dict = field(
17+
default_factory=lambda: {
18+
r"^patch_embed\.(.*)$": r"patch_embed.\1",
19+
r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1",
20+
r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1",
21+
r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1",
22+
r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
23+
r"transformer_blocks.\1.attn1.to_q.\2",
24+
r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
25+
r"transformer_blocks.\1.attn1.to_k.\2",
26+
r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
27+
r"transformer_blocks.\1.attn1.to_v.\2",
28+
r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
29+
r"transformer_blocks.\1.attn1.to_out.\2",
30+
r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
31+
r"transformer_blocks.\1.attn1.norm_q.\2",
32+
r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
33+
r"transformer_blocks.\1.attn1.norm_k.\2",
34+
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
35+
r"transformer_blocks.\1.attn2.to_q.\2",
36+
r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
37+
r"transformer_blocks.\1.attn2.to_k.\2",
38+
r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
39+
r"transformer_blocks.\1.attn2.to_v.\2",
40+
r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
41+
r"transformer_blocks.\1.attn2.to_out.\2",
42+
r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$":
43+
r"transformer_blocks.\1.attn2.norm_q.\2",
44+
r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$":
45+
r"transformer_blocks.\1.attn2.norm_k.\2",
46+
r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
47+
r"transformer_blocks.\1.ff.fc_in.\2",
48+
r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
49+
r"transformer_blocks.\1.ff.fc_out.\2",
50+
r"^norm_out\.(.*)$": r"norm_out.\1",
51+
r"^proj_out\.(.*)$": r"proj_out.\1",
52+
})
53+
54+
lora_param_names_mapping: dict = field(
55+
default_factory=lambda: {
56+
r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
57+
r"transformer_blocks.\1.attn1.to_q.\2",
58+
r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
59+
r"transformer_blocks.\1.attn1.to_k.\2",
60+
r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
61+
r"transformer_blocks.\1.attn1.to_v.\2",
62+
r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
63+
r"transformer_blocks.\1.attn1.to_out.\2",
64+
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
65+
r"transformer_blocks.\1.attn2.to_q.\2",
66+
r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
67+
r"transformer_blocks.\1.attn2.to_k.\2",
68+
r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
69+
r"transformer_blocks.\1.attn2.to_v.\2",
70+
r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
71+
r"transformer_blocks.\1.attn2.to_out.\2",
72+
r"^transformer_blocks\.(\d+)\.ff\.(.*)$":
73+
r"transformer_blocks.\1.ff.\2",
74+
})
75+
76+
# Cosmos-specific config parameters based on transformer_cosmos.py
77+
in_channels: int = 16
78+
out_channels: int = 16
79+
num_attention_heads: int = 16
80+
attention_head_dim: int = 128
81+
num_layers: int = 28
82+
mlp_ratio: float = 4.0
83+
text_embed_dim: int = 1024
84+
adaln_lora_dim: int = 256
85+
max_size: tuple[int, int, int] = (128, 240, 240)
86+
patch_size: tuple[int, int, int] = (1, 2, 2)
87+
rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)
88+
concat_padding_mask: bool = True
89+
extra_pos_embed_type: str | None = None
90+
qk_norm: str = "rms_norm"
91+
eps: float = 1e-6
92+
exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
93+
94+
def __post_init__(self):
95+
super().__post_init__()
96+
self.out_channels = self.out_channels or self.in_channels
97+
self.hidden_size = self.num_attention_heads * self.attention_head_dim
98+
self.num_channels_latents = self.in_channels
99+
100+
101+
@dataclass
102+
class CosmosVideoConfig(DiTConfig):
103+
arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig)
104+
prefix: str = "Cosmos"

fastvideo/configs/models/encoders/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from fastvideo.configs.models.encoders.clip import (
66
CLIPTextConfig, CLIPVisionConfig, WAN2_1ControlCLIPVisionConfig)
77
from fastvideo.configs.models.encoders.llama import LlamaConfig
8-
from fastvideo.configs.models.encoders.t5 import T5Config
8+
from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig
99

1010
__all__ = [
1111
"EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig",
1212
"BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig",
13-
"WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config"
13+
"WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config",
14+
"T5LargeConfig"
1415
]

fastvideo/configs/models/encoders/t5.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,31 @@ def __post_init__(self):
7070
}
7171

7272

73+
@dataclass
74+
class T5LargeArchConfig(T5ArchConfig):
75+
"""T5 Large architecture config with parameters for your specific model."""
76+
d_model: int = 1024
77+
d_kv: int = 128
78+
d_ff: int = 65536
79+
num_layers: int = 24
80+
num_decoder_layers: int | None = 24
81+
num_heads: int = 128
82+
decoder_start_token_id: int = 0
83+
n_positions: int = 512
84+
task_specific_params: dict | None = None
85+
86+
7387
@dataclass
7488
class T5Config(TextEncoderConfig):
7589
arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)
7690

7791
prefix: str = "t5"
92+
93+
94+
@dataclass
95+
class T5LargeConfig(TextEncoderConfig):
96+
"""T5 Large configuration for your specific model."""
97+
arch_config: TextEncoderArchConfig = field(
98+
default_factory=T5LargeArchConfig)
99+
100+
prefix: str = "t5"

fastvideo/configs/models/vaes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig
12
from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
23
from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig
34
from fastvideo.configs.models.vaes.wanvae import WanVAEConfig
@@ -6,4 +7,5 @@
67
"HunyuanVAEConfig",
78
"WanVAEConfig",
89
"StepVideoVAEConfig",
10+
"CosmosVAEConfig",
911
]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from dataclasses import dataclass, field
3+
4+
import torch
5+
6+
from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig
7+
8+
9+
@dataclass
10+
class CosmosVAEArchConfig(VAEArchConfig):
11+
_name_or_path: str = ""
12+
base_dim: int = 96
13+
z_dim: int = 16
14+
dim_mult: tuple[int, ...] = (1, 2, 4, 4)
15+
num_res_blocks: int = 2
16+
attn_scales: tuple[float, ...] = ()
17+
temperal_downsample: tuple[bool, ...] = (False, True, True)
18+
dropout: float = 0.0
19+
decoder_base_dim: int | None = None
20+
is_residual: bool = False
21+
in_channels: int = 3
22+
out_channels: int = 3
23+
patch_size: int | None = None
24+
scale_factor_temporal: int = 4
25+
scale_factor_spatial: int = 8
26+
clip_output: bool = True
27+
latents_mean: tuple[float, ...] = (
28+
-0.7571,
29+
-0.7089,
30+
-0.9113,
31+
0.1075,
32+
-0.1745,
33+
0.9653,
34+
-0.1517,
35+
1.5508,
36+
0.4134,
37+
-0.0715,
38+
0.5517,
39+
-0.3632,
40+
-0.1922,
41+
-0.9497,
42+
0.2503,
43+
-0.2921,
44+
)
45+
latents_std: tuple[float, ...] = (
46+
2.8184,
47+
1.4541,
48+
2.3275,
49+
2.6558,
50+
1.2196,
51+
1.7708,
52+
2.6052,
53+
2.0743,
54+
3.2687,
55+
2.1526,
56+
2.8652,
57+
1.5579,
58+
1.6382,
59+
1.1253,
60+
2.8251,
61+
1.9160,
62+
)
63+
temporal_compression_ratio = 4
64+
spatial_compression_ratio = 8
65+
66+
def __post_init__(self):
67+
self.scaling_factor: torch.Tensor = 1.0 / torch.tensor(
68+
self.latents_std).view(1, self.z_dim, 1, 1, 1)
69+
self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(
70+
1, self.z_dim, 1, 1, 1)
71+
self.temporal_compression_ratio = self.scale_factor_temporal
72+
self.spatial_compression_ratio = self.scale_factor_spatial
73+
74+
75+
@dataclass
76+
class CosmosVAEConfig(VAEConfig):
77+
arch_config: CosmosVAEArchConfig = field(
78+
default_factory=CosmosVAEArchConfig)
79+
use_feature_cache: bool = True
80+
81+
use_tiling: bool = False
82+
use_temporal_tiling: bool = False
83+
use_parallel_tiling: bool = False
84+
85+
def __post_init__(self):
86+
self.blend_num_frames = (self.tile_sample_min_num_frames -
87+
self.tile_sample_stride_num_frames) * 2

fastvideo/configs/pipelines/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastvideo.configs.pipelines.base import (PipelineConfig,
22
SlidingTileAttnConfig)
3+
from fastvideo.configs.pipelines.cosmos import CosmosConfig
34
from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
45
from fastvideo.configs.pipelines.registry import (
56
get_pipeline_config_cls_from_name)
@@ -12,5 +13,6 @@
1213
"HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
1314
"SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",
1415
"WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig",
15-
"SelfForcingWanT2V480PConfig", "get_pipeline_config_cls_from_name"
16+
"SelfForcingWanT2V480PConfig", "CosmosConfig",
17+
"get_pipeline_config_cls_from_name"
1618
]
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from collections.abc import Callable
3+
from dataclasses import dataclass, field
4+
5+
import torch
6+
7+
from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
8+
from fastvideo.configs.models.dits import CosmosVideoConfig
9+
from fastvideo.configs.models.encoders import BaseEncoderOutput, T5LargeConfig
10+
from fastvideo.configs.models.vaes import CosmosVAEConfig
11+
from fastvideo.configs.pipelines.base import PipelineConfig
12+
13+
14+
def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
15+
"""Postprocess T5 Large text encoder outputs for Cosmos pipeline.
16+
17+
Return raw last_hidden_state without truncation/padding.
18+
"""
19+
hidden_state = outputs.last_hidden_state
20+
21+
if hidden_state is None:
22+
raise ValueError("T5 Large outputs missing last_hidden_state")
23+
24+
nan_count = torch.isnan(hidden_state).sum()
25+
if nan_count > 0:
26+
hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)
27+
28+
return hidden_state
29+
30+
31+
@dataclass
32+
class CosmosConfig(PipelineConfig):
33+
"""Configuration for Cosmos2 Video2World pipeline matching diffusers."""
34+
35+
dit_config: DiTConfig = field(default_factory=CosmosVideoConfig)
36+
37+
vae_config: VAEConfig = field(default_factory=CosmosVAEConfig)
38+
39+
text_encoder_configs: tuple[EncoderConfig, ...] = field(
40+
default_factory=lambda: (T5LargeConfig(), ))
41+
postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
42+
...] = field(default_factory=lambda:
43+
(t5_large_postprocess_text, ))
44+
45+
dit_precision: str = "bf16"
46+
vae_precision: str = "fp16"
47+
text_encoder_precisions: tuple[str, ...] = field(
48+
default_factory=lambda: ("bf16", ))
49+
50+
conditioning_strategy: str = "frame_replace"
51+
min_num_conditional_frames: int = 1
52+
max_num_conditional_frames: int = 2
53+
sigma_conditional: float = 0.0001
54+
sigma_data: float = 1.0
55+
state_ch: int = 16
56+
state_t: int = 24
57+
text_encoder_class: str = "T5"
58+
59+
embedded_cfg_scale: int = 6
60+
flow_shift: float = 1.0
61+
62+
def __post_init__(self):
63+
self.vae_config.load_encoder = True
64+
self.vae_config.load_decoder = True
65+
66+
self._vae_latent_dim = 16

fastvideo/configs/pipelines/registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Callable
66

77
from fastvideo.configs.pipelines.base import PipelineConfig
8+
from fastvideo.configs.pipelines.cosmos import CosmosConfig
89
from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
910
from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig
1011

@@ -40,6 +41,7 @@
4041
"Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config,
4142
"Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config,
4243
"Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config,
44+
"nvidia/Cosmos-Predict2-2B-Video2World": CosmosConfig,
4345
# Add other specific weight variants
4446
}
4547

@@ -51,6 +53,7 @@
5153
"wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(),
5254
"wancausaldmdpipeline": lambda id: "wancausaldmdpipeline" in id.lower(),
5355
"stepvideo": lambda id: "stepvideo" in id.lower(),
56+
"cosmos": lambda id: "cosmos" in id.lower(),
5457
# Add other pipeline architecture detectors
5558
}
5659

fastvideo/configs/sample/cosmos.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from dataclasses import dataclass
3+
4+
from fastvideo.configs.sample.base import SamplingParam
5+
6+
7+
@dataclass
8+
class Cosmos_Predict2_2B_Video2World_SamplingParam(SamplingParam):
9+
# Video parameters
10+
height: int = 704
11+
width: int = 1280
12+
num_frames: int = 93
13+
fps: int = 16
14+
15+
# Denoising stage
16+
guidance_scale: float = 7.0
17+
negative_prompt: str = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
18+
num_inference_steps: int = 35

0 commit comments

Comments
 (0)