Commit 543fea8

[Feature] Add Cosmos2 i2v pipeline (#837)
1 parent bdec816 commit 543fea8

File tree

30 files changed: +2415 -58 lines changed

.buildkite/scripts/pr_test.sh

Lines changed: 5 additions & 5 deletions
@@ -31,9 +31,9 @@ log "Setting up Modal authentication from Buildkite secrets..."
 MODAL_TOKEN_ID=$(buildkite-agent secret get modal_token_id)
 MODAL_TOKEN_SECRET=$(buildkite-agent secret get modal_token_secret)
 
+# Retrieve other secrets
 WANDB_API_KEY=$(buildkite-agent secret get wandb_api_key)
-
-WANDB_API_KEY=$(buildkite-agent secret get wandb_api_key)
+HF_API_KEY=$(buildkite-agent secret get hf_api_key)
 
 if [ -n "$MODAL_TOKEN_ID" ] && [ -n "$MODAL_TOKEN_SECRET" ]; then
     log "Retrieved Modal credentials from Buildkite secrets"
@@ -63,15 +63,15 @@ MODAL_ENV="BUILDKITE_REPO=$BUILDKITE_REPO BUILDKITE_COMMIT=$BUILDKITE_COMMIT BUI
 case "$TEST_TYPE" in
     "encoder")
         log "Running encoder tests..."
-        MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_encoder_tests"
+        MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_encoder_tests"
         ;;
     "vae")
        log "Running VAE tests..."
-        MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_vae_tests"
+        MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_vae_tests"
         ;;
     "transformer")
         log "Running transformer tests..."
-        MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_transformer_tests"
+        MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_transformer_tests"
        ;;
     "ssim")
         log "Running SSIM tests..."

fastvideo/configs/models/dits/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,9 @@
+from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
 from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
 
-__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"]
+__all__ = [
+    "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig",
+    "CosmosVideoConfig"
+]

fastvideo/configs/models/dits/cosmos.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass, field
+
+from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig
+
+
+def is_transformer_blocks(n: str, m) -> bool:
+    return "transformer_blocks" in n and str.isdigit(n.split(".")[-1])
+
+
+@dataclass
+class CosmosArchConfig(DiTArchConfig):
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda: [is_transformer_blocks])
+
+    param_names_mapping: dict = field(
+        default_factory=lambda: {
+            r"^patch_embed\.(.*)$": r"patch_embed.\1",
+            r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1",
+            r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1",
+            r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
+            r"transformer_blocks.\1.attn1.norm_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
+            r"transformer_blocks.\1.attn1.norm_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$":
+            r"transformer_blocks.\1.attn2.norm_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$":
+            r"transformer_blocks.\1.attn2.norm_k.\2",
+            r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
+            r"transformer_blocks.\1.ff.fc_in.\2",
+            r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
+            r"transformer_blocks.\1.ff.fc_out.\2",
+            r"^norm_out\.(.*)$": r"norm_out.\1",
+            r"^proj_out\.(.*)$": r"proj_out.\1",
+        })
+
+    lora_param_names_mapping: dict = field(
+        default_factory=lambda: {
+            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
+            r"transformer_blocks.\1.attn1.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_q.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_k.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_v.\2",
+            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
+            r"transformer_blocks.\1.attn2.to_out.\2",
+            r"^transformer_blocks\.(\d+)\.ff\.(.*)$":
+            r"transformer_blocks.\1.ff.\2",
+        })
+
+    # Cosmos-specific config parameters based on transformer_cosmos.py
+    in_channels: int = 16
+    out_channels: int = 16
+    num_attention_heads: int = 16
+    attention_head_dim: int = 128
+    num_layers: int = 28
+    mlp_ratio: float = 4.0
+    text_embed_dim: int = 1024
+    adaln_lora_dim: int = 256
+    max_size: tuple[int, int, int] = (128, 240, 240)
+    patch_size: tuple[int, int, int] = (1, 2, 2)
+    rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)
+    concat_padding_mask: bool = True
+    extra_pos_embed_type: str | None = None
+    qk_norm: str = "rms_norm"
+    eps: float = 1e-6
+    exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.out_channels = self.out_channels or self.in_channels
+        self.hidden_size = self.num_attention_heads * self.attention_head_dim
+        self.num_channels_latents = self.in_channels
+
+
+@dataclass
+class CosmosVideoConfig(DiTConfig):
+    arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig)
+    prefix: str = "Cosmos"
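
Most entries in param_names_mapping map a diffusers checkpoint key to itself; the real renames are the attention output projections (to_out.0 -> to_out) and the feed-forward layers (ff.net.0.proj -> ff.fc_in, ff.net.2 -> ff.fc_out). A standalone sketch of how such a regex table is typically applied while loading a checkpoint (the remap helper is illustrative, not FastVideo's actual loader):

import re

# Subset of the mapping above: diffusers key pattern -> FastVideo key.
PARAM_NAMES_MAPPING = {
    r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
    r"transformer_blocks.\1.attn1.to_out.\2",
    r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
    r"transformer_blocks.\1.ff.fc_in.\2",
    r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
    r"transformer_blocks.\1.ff.fc_out.\2",
}


def remap(name: str) -> str:
    """Apply the first matching rewrite; unmatched names pass through."""
    for pattern, repl in PARAM_NAMES_MAPPING.items():
        new_name, count = re.subn(pattern, repl, name)
        if count:
            return new_name
    return name


assert remap("transformer_blocks.3.ff.net.0.proj.weight") == \
    "transformer_blocks.3.ff.fc_in.weight"
assert remap("transformer_blocks.3.attn1.to_out.0.bias") == \
    "transformer_blocks.3.attn1.to_out.bias"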

fastvideo/configs/models/encoders/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@
 from fastvideo.configs.models.encoders.clip import (
     CLIPTextConfig, CLIPVisionConfig, WAN2_1ControlCLIPVisionConfig)
 from fastvideo.configs.models.encoders.llama import LlamaConfig
-from fastvideo.configs.models.encoders.t5 import T5Config
+from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig
 
 __all__ = [
     "EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig",
     "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig",
-    "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config"
+    "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config", "T5LargeConfig"
 ]

fastvideo/configs/models/encoders/t5.py

Lines changed: 23 additions & 0 deletions
@@ -70,8 +70,31 @@ def __post_init__(self):
         }
 
 
+@dataclass
+class T5LargeArchConfig(T5ArchConfig):
+    """T5 Large architecture config with parameters for the T5 variant used by Cosmos."""
+    d_model: int = 1024
+    d_kv: int = 128
+    d_ff: int = 65536
+    num_layers: int = 24
+    num_decoder_layers: int | None = 24
+    num_heads: int = 128
+    decoder_start_token_id: int = 0
+    n_positions: int = 512
+    task_specific_params: dict | None = None
+
+
 @dataclass
 class T5Config(TextEncoderConfig):
     arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)
 
     prefix: str = "t5"
+
+
+@dataclass
+class T5LargeConfig(TextEncoderConfig):
+    """T5 Large configuration for the Cosmos text encoder."""
+    arch_config: TextEncoderArchConfig = field(
+        default_factory=T5LargeArchConfig)
+
+    prefix: str = "t5"
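
Despite the "Large" name, these dimensions (d_model=1024, d_ff=65536, num_heads=128, num_layers=24) are far larger than t5-large; they match the published t5-11b configuration, which is presumably the checkpoint this config targets. A quick sanity check of the derived sizes, assuming the standard T5 convention that the attention inner dimension is num_heads * d_kv:

from dataclasses import dataclass


@dataclass
class T5Dims:
    # Standalone mirror of T5LargeArchConfig's fields, for arithmetic only.
    d_model: int = 1024
    d_kv: int = 128
    d_ff: int = 65536
    num_heads: int = 128


dims = T5Dims()
print(dims.num_heads * dims.d_kv)  # 16384: Q/K/V inner width, 16x d_model
print(dims.d_ff // dims.d_model)   # 64: feed-forward expansion factor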

fastvideo/configs/models/vaes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig
 from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
 from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig
 from fastvideo.configs.models.vaes.wanvae import WanVAEConfig
@@ -6,4 +7,5 @@
     "HunyuanVAEConfig",
     "WanVAEConfig",
     "StepVideoVAEConfig",
+    "CosmosVAEConfig",
 ]

fastvideo/configs/models/vaes/cosmosvae.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass, field
+
+import torch
+
+from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig
+
+
+@dataclass
+class CosmosVAEArchConfig(VAEArchConfig):
+    _name_or_path: str = ""
+    base_dim: int = 96
+    z_dim: int = 16
+    dim_mult: tuple[int, ...] = (1, 2, 4, 4)
+    num_res_blocks: int = 2
+    attn_scales: tuple[float, ...] = ()
+    temperal_downsample: tuple[bool, ...] = (False, True, True)
+    dropout: float = 0.0
+    decoder_base_dim: int | None = None
+    is_residual: bool = False
+    in_channels: int = 3
+    out_channels: int = 3
+    patch_size: int | None = None
+    scale_factor_temporal: int = 4
+    scale_factor_spatial: int = 8
+    clip_output: bool = True
+    latents_mean: tuple[float, ...] = (
+        -0.7571,
+        -0.7089,
+        -0.9113,
+        0.1075,
+        -0.1745,
+        0.9653,
+        -0.1517,
+        1.5508,
+        0.4134,
+        -0.0715,
+        0.5517,
+        -0.3632,
+        -0.1922,
+        -0.9497,
+        0.2503,
+        -0.2921,
+    )
+    latents_std: tuple[float, ...] = (
+        2.8184,
+        1.4541,
+        2.3275,
+        2.6558,
+        1.2196,
+        1.7708,
+        2.6052,
+        2.0743,
+        3.2687,
+        2.1526,
+        2.8652,
+        1.5579,
+        1.6382,
+        1.1253,
+        2.8251,
+        1.9160,
+    )
+    temporal_compression_ratio = 4
+    spatial_compression_ratio = 8
+
+    def __post_init__(self):
+        self.scaling_factor: torch.Tensor = 1.0 / torch.tensor(
+            self.latents_std).view(1, self.z_dim, 1, 1, 1)
+        self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(
+            1, self.z_dim, 1, 1, 1)
+        self.temporal_compression_ratio = self.scale_factor_temporal
+        self.spatial_compression_ratio = self.scale_factor_spatial
+
+
+@dataclass
+class CosmosVAEConfig(VAEConfig):
+    arch_config: CosmosVAEArchConfig = field(
+        default_factory=CosmosVAEArchConfig)
+    use_feature_cache: bool = True
+
+    use_tiling: bool = False
+    use_temporal_tiling: bool = False
+    use_parallel_tiling: bool = False
+
+    def __post_init__(self):
+        self.blend_num_frames = (self.tile_sample_min_num_frames -
+                                 self.tile_sample_stride_num_frames) * 2
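
__post_init__ reshapes the per-channel statistics to (1, z_dim, 1, 1, 1) so they broadcast over (batch, channel, time, height, width) latents; the temperal_downsample spelling apparently follows the upstream diffusers field name. A small sketch of the normalization these tensors enable (the normalize-before-diffusion / denormalize-before-decode convention is the usual one and is assumed here, not shown in this file):

import torch

z_dim = 16
# Placeholder statistics; the real config lists 16 measured values each.
latents_mean = torch.linspace(-1.0, 1.0, z_dim)
latents_std = torch.linspace(1.0, 3.0, z_dim)

shift_factor = latents_mean.view(1, z_dim, 1, 1, 1)
scaling_factor = 1.0 / latents_std.view(1, z_dim, 1, 1, 1)

latents = torch.randn(2, z_dim, 8, 30, 30)  # (B, C, T, H, W)
normalized = (latents - shift_factor) * scaling_factor
restored = normalized / scaling_factor + shift_factor
assert torch.allclose(restored, latents, atol=1e-5)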

fastvideo/configs/pipelines/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
 from fastvideo.configs.pipelines.base import (PipelineConfig,
                                               SlidingTileAttnConfig)
+from fastvideo.configs.pipelines.cosmos import CosmosConfig
 from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
 from fastvideo.configs.pipelines.registry import (
     get_pipeline_config_cls_from_name)
@@ -12,5 +13,6 @@
     "HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
     "SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",
     "WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig",
-    "SelfForcingWanT2V480PConfig", "get_pipeline_config_cls_from_name"
+    "SelfForcingWanT2V480PConfig", "CosmosConfig",
+    "get_pipeline_config_cls_from_name"
 ]

fastvideo/configs/pipelines/cosmos.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+from collections.abc import Callable
+from dataclasses import dataclass, field
+
+import torch
+
+from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
+from fastvideo.configs.models.dits import CosmosVideoConfig
+from fastvideo.configs.models.encoders import BaseEncoderOutput, T5LargeConfig
+from fastvideo.configs.models.vaes import CosmosVAEConfig
+from fastvideo.configs.pipelines.base import PipelineConfig
+
+
+def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
+    """Postprocess T5 Large text encoder outputs for Cosmos pipeline.
+
+    Return raw last_hidden_state without truncation/padding.
+    """
+    hidden_state = outputs.last_hidden_state
+
+    if hidden_state is None:
+        raise ValueError("T5 Large outputs missing last_hidden_state")
+
+    nan_count = torch.isnan(hidden_state).sum()
+    if nan_count > 0:
+        hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)
+
+    # Zero out embeddings beyond actual sequence length
+    if outputs.attention_mask is not None:
+        attention_mask = outputs.attention_mask
+        lengths = attention_mask.sum(dim=1).cpu()
+        for i, length in enumerate(lengths):
+            hidden_state[i, length:] = 0
+
+    return hidden_state
+
+
+@dataclass
+class CosmosConfig(PipelineConfig):
+    """Configuration for Cosmos2 Video2World pipeline matching diffusers."""
+
+    dit_config: DiTConfig = field(default_factory=CosmosVideoConfig)
+
+    vae_config: VAEConfig = field(default_factory=CosmosVAEConfig)
+
+    text_encoder_configs: tuple[EncoderConfig, ...] = field(
+        default_factory=lambda: (T5LargeConfig(), ))
+    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
+                                  ...] = field(default_factory=lambda:
+                                               (t5_large_postprocess_text, ))
+
+    dit_precision: str = "bf16"
+    vae_precision: str = "fp16"
+    text_encoder_precisions: tuple[str, ...] = field(
+        default_factory=lambda: ("bf16", ))
+
+    conditioning_strategy: str = "frame_replace"
+    min_num_conditional_frames: int = 1
+    max_num_conditional_frames: int = 2
+    sigma_conditional: float = 0.0001
+    sigma_data: float = 1.0
+    state_ch: int = 16
+    state_t: int = 24
+    text_encoder_class: str = "T5"
+
+    embedded_cfg_scale: int = 6
+    flow_shift: float = 1.0
+
+    def __post_init__(self):
+        self.vae_config.load_encoder = True
+        self.vae_config.load_decoder = True
+
+        self._vae_latent_dim = 16
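
Note that t5_large_postprocess_text zeroes embeddings past each prompt's true length instead of truncating, so downstream consumers always see a fixed-length tensor. A self-contained demo of that behavior (FakeOutput is a stand-in for this demo only; the real pipeline passes a BaseEncoderOutput):

import torch


class FakeOutput:
    # Minimal stand-in for BaseEncoderOutput, for this demo only.
    def __init__(self, last_hidden_state, attention_mask):
        self.last_hidden_state = last_hidden_state
        self.attention_mask = attention_mask


def postprocess(outputs):
    # Mirrors the logic of t5_large_postprocess_text above.
    h = outputs.last_hidden_state
    if h is None:
        raise ValueError("missing last_hidden_state")
    h = h.masked_fill(torch.isnan(h), 0.0)
    if outputs.attention_mask is not None:
        lengths = outputs.attention_mask.sum(dim=1).cpu()
        for i, length in enumerate(lengths):
            h[i, length:] = 0
    return h


# Batch of 2 prompts padded to length 5; the second has 3 real tokens.
hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
out = postprocess(FakeOutput(hidden, mask))
assert out.shape == (2, 5, 8)       # length preserved, not truncated
assert out[1, 3:].abs().sum() == 0  # padded positions zeroed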
