10 changes: 5 additions & 5 deletions .buildkite/scripts/pr_test.sh
@@ -31,9 +31,9 @@ log "Setting up Modal authentication from Buildkite secrets..."
MODAL_TOKEN_ID=$(buildkite-agent secret get modal_token_id)
MODAL_TOKEN_SECRET=$(buildkite-agent secret get modal_token_secret)

-# Retrieve other secrets
-WANDB_API_KEY=$(buildkite-agent secret get wandb_api_key)
 
+WANDB_API_KEY=$(buildkite-agent secret get wandb_api_key)
+HF_API_KEY=$(buildkite-agent secret get hf_api_key)

if [ -n "$MODAL_TOKEN_ID" ] && [ -n "$MODAL_TOKEN_SECRET" ]; then
log "Retrieved Modal credentials from Buildkite secrets"
@@ -63,15 +63,15 @@ MODAL_ENV="BUILDKITE_REPO=$BUILDKITE_REPO BUILDKITE_COMMIT=$BUILDKITE_COMMIT BUI
case "$TEST_TYPE" in
"encoder")
log "Running encoder tests..."
MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_encoder_tests"
MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_encoder_tests"
;;
"vae")
log "Running VAE tests..."
MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_vae_tests"
MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_vae_tests"
;;
"transformer")
log "Running transformer tests..."
MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_transformer_tests"
MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_transformer_tests"
;;
"ssim")
log "Running SSIM tests..."
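Note for reviewers (not part of the diff): a minimal sketch of the secret-to-Modal plumbing this change introduces, with an illustrative file name and an assumed eval-style dispatch at the end; only buildkite-agent secret get and python3 -m modal run appear in the actual script.

# Sketch: fetch the token from Buildkite and prepend it to the Modal
# invocation, as done for each test type above. The VAR=value prefix scopes
# the token to the Modal process instead of exporting it build-wide.
HF_API_KEY=$(buildkite-agent secret get hf_api_key)

if [ -z "$HF_API_KEY" ]; then
  echo "warning: hf_api_key secret is empty; gated HF downloads may fail" >&2
fi

# "modal_test.py" is a placeholder for $MODAL_TEST_FILE.
MODAL_COMMAND="HF_API_KEY=$HF_API_KEY python3 -m modal run modal_test.py::run_encoder_tests"
eval "$MODAL_COMMAND"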
6 changes: 5 additions & 1 deletion fastvideo/configs/models/dits/__init__.py
@@ -1,5 +1,9 @@
+from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig
 from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
 
-__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"]
+__all__ = [
+    "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig",
+    "CosmosVideoConfig"
+]
104 changes: 104 additions & 0 deletions fastvideo/configs/models/dits/cosmos.py
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_transformer_blocks(n: str, m) -> bool:
return "transformer_blocks" in n and str.isdigit(n.split(".")[-1])


@dataclass
class CosmosArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(
default_factory=lambda: [is_transformer_blocks])

param_names_mapping: dict = field(
default_factory=lambda: {
r"^patch_embed\.(.*)$": r"patch_embed.\1",
r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1",
r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1",
r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1",
r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
r"transformer_blocks.\1.attn1.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
r"transformer_blocks.\1.attn1.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
r"transformer_blocks.\1.attn1.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
r"transformer_blocks.\1.attn1.to_out.\2",
r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
r"transformer_blocks.\1.attn1.norm_q.\2",
r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
r"transformer_blocks.\1.attn1.norm_k.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
r"transformer_blocks.\1.attn2.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
r"transformer_blocks.\1.attn2.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
r"transformer_blocks.\1.attn2.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
r"transformer_blocks.\1.attn2.to_out.\2",
r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$":
r"transformer_blocks.\1.attn2.norm_q.\2",
r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$":
r"transformer_blocks.\1.attn2.norm_k.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
r"transformer_blocks.\1.ff.fc_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
r"transformer_blocks.\1.ff.fc_out.\2",
r"^norm_out\.(.*)$": r"norm_out.\1",
r"^proj_out\.(.*)$": r"proj_out.\1",
})

lora_param_names_mapping: dict = field(
default_factory=lambda: {
r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
r"transformer_blocks.\1.attn1.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
r"transformer_blocks.\1.attn1.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
r"transformer_blocks.\1.attn1.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
r"transformer_blocks.\1.attn1.to_out.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
r"transformer_blocks.\1.attn2.to_q.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
r"transformer_blocks.\1.attn2.to_k.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
r"transformer_blocks.\1.attn2.to_v.\2",
r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
r"transformer_blocks.\1.attn2.to_out.\2",
r"^transformer_blocks\.(\d+)\.ff\.(.*)$":
r"transformer_blocks.\1.ff.\2",
})

# Cosmos-specific config parameters based on transformer_cosmos.py
in_channels: int = 16
out_channels: int = 16
num_attention_heads: int = 16
attention_head_dim: int = 128
num_layers: int = 28
mlp_ratio: float = 4.0
text_embed_dim: int = 1024
adaln_lora_dim: int = 256
max_size: tuple[int, int, int] = (128, 240, 240)
patch_size: tuple[int, int, int] = (1, 2, 2)
rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)
concat_padding_mask: bool = True
extra_pos_embed_type: str | None = None
qk_norm: str = "rms_norm"
eps: float = 1e-6
exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])

def __post_init__(self):
super().__post_init__()
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.in_channels


@dataclass
class CosmosVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig)
prefix: str = "Cosmos"
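For context (not part of the diff): a minimal sketch of how a regex table like param_names_mapping above remaps diffusers-style checkpoint keys; the remap_key helper is hypothetical, not FastVideo's actual loader.

import re

# Two entries copied from the mapping above: diffusers' ff.net.{0.proj,2}
# become FastVideo's ff.{fc_in,fc_out}.
param_names_mapping = {
    r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
        r"transformer_blocks.\1.ff.fc_in.\2",
    r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
        r"transformer_blocks.\1.ff.fc_out.\2",
}

def remap_key(key: str, mapping: dict[str, str]) -> str:
    # First pattern that matches wins; unmatched keys pass through unchanged.
    for pattern, repl in mapping.items():
        if re.match(pattern, key):
            return re.sub(pattern, repl, key)
    return key

assert remap_key("transformer_blocks.3.ff.net.0.proj.weight",
                 param_names_mapping) == "transformer_blocks.3.ff.fc_in.weight"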
4 changes: 2 additions & 2 deletions fastvideo/configs/models/encoders/__init__.py
@@ -5,10 +5,10 @@
from fastvideo.configs.models.encoders.clip import (
CLIPTextConfig, CLIPVisionConfig, WAN2_1ControlCLIPVisionConfig)
from fastvideo.configs.models.encoders.llama import LlamaConfig
-from fastvideo.configs.models.encoders.t5 import T5Config
+from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig

__all__ = [
"EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig",
"BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig",
"WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config"
"WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config", "T5LargeConfig"
]
23 changes: 23 additions & 0 deletions fastvideo/configs/models/encoders/t5.py
@@ -70,8 +70,31 @@ def __post_init__(self):
}


@dataclass
class T5LargeArchConfig(T5ArchConfig):
"""T5 Large architecture config with parameters for your specific model."""
d_model: int = 1024
d_kv: int = 128
d_ff: int = 65536
num_layers: int = 24
num_decoder_layers: int | None = 24
num_heads: int = 128
decoder_start_token_id: int = 0
n_positions: int = 512
task_specific_params: dict | None = None


@dataclass
class T5Config(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)

prefix: str = "t5"


@dataclass
class T5LargeConfig(TextEncoderConfig):
"""T5 Large configuration for your specific model."""
arch_config: TextEncoderArchConfig = field(
default_factory=T5LargeArchConfig)

prefix: str = "t5"
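Reviewer note (not part of the diff): in T5, the attention inner dimension is num_heads * d_kv and is decoupled from d_model, so 128 heads with d_kv=128 alongside d_model=1024 is consistent with the original 11B-scale T5 layout rather than a typo. A quick sanity check, assuming the import path added above:

from fastvideo.configs.models.encoders import T5LargeConfig

arch = T5LargeConfig().arch_config

# T5 projects d_model -> num_heads * d_kv inside attention, so the two
# dimensions need not match.
assert arch.num_heads * arch.d_kv == 16384
assert arch.d_model == 1024 and arch.d_ff == 65536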
2 changes: 2 additions & 0 deletions fastvideo/configs/models/vaes/__init__.py
@@ -1,3 +1,4 @@
from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig
from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig
from fastvideo.configs.models.vaes.wanvae import WanVAEConfig
@@ -6,4 +7,5 @@
"HunyuanVAEConfig",
"WanVAEConfig",
"StepVideoVAEConfig",
"CosmosVAEConfig",
]
87 changes: 87 additions & 0 deletions fastvideo/configs/models/vaes/cosmosvae.py
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

import torch

from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig


@dataclass
class CosmosVAEArchConfig(VAEArchConfig):
_name_or_path: str = ""
base_dim: int = 96
z_dim: int = 16
dim_mult: tuple[int, ...] = (1, 2, 4, 4)
num_res_blocks: int = 2
attn_scales: tuple[float, ...] = ()
temperal_downsample: tuple[bool, ...] = (False, True, True)
dropout: float = 0.0
decoder_base_dim: int | None = None
is_residual: bool = False
in_channels: int = 3
out_channels: int = 3
patch_size: int | None = None
scale_factor_temporal: int = 4
scale_factor_spatial: int = 8
clip_output: bool = True
latents_mean: tuple[float, ...] = (
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
)
latents_std: tuple[float, ...] = (
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
)
temporal_compression_ratio = 4
spatial_compression_ratio = 8

def __post_init__(self):
self.scaling_factor: torch.Tensor = 1.0 / torch.tensor(
self.latents_std).view(1, self.z_dim, 1, 1, 1)
self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view(
1, self.z_dim, 1, 1, 1)
self.temporal_compression_ratio = self.scale_factor_temporal
self.spatial_compression_ratio = self.scale_factor_spatial


@dataclass
class CosmosVAEConfig(VAEConfig):
arch_config: CosmosVAEArchConfig = field(
default_factory=CosmosVAEArchConfig)
use_feature_cache: bool = True

use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False

def __post_init__(self):
self.blend_num_frames = (self.tile_sample_min_num_frames -
self.tile_sample_stride_num_frames) * 2
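For context (illustrative; the real call sites live in the pipeline): the shift/scale tensors built in __post_init__ above implement the usual mean/std normalization of VAE latents, along these lines:

import torch

from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEArchConfig

arch = CosmosVAEArchConfig()

# Dummy latents shaped (batch, z_dim, frames, height, width).
raw = torch.randn(2, arch.z_dim, 4, 30, 30)

# Normalize into the diffusion model's latent space: (x - mean) / std,
# with scaling_factor == 1 / std and shift_factor == mean.
normalized = (raw - arch.shift_factor) * arch.scaling_factor

# The inverse mapping recovers VAE-space latents before decoding.
recovered = normalized / arch.scaling_factor + arch.shift_factor
assert torch.allclose(recovered, raw, atol=1e-5)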
4 changes: 3 additions & 1 deletion fastvideo/configs/pipelines/__init__.py
@@ -1,5 +1,6 @@
from fastvideo.configs.pipelines.base import (PipelineConfig,
SlidingTileAttnConfig)
from fastvideo.configs.pipelines.cosmos import CosmosConfig
from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig
from fastvideo.configs.pipelines.registry import (
get_pipeline_config_cls_from_name)
@@ -12,5 +13,6 @@
"HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
"SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",
"WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig",
"SelfForcingWanT2V480PConfig", "get_pipeline_config_cls_from_name"
"SelfForcingWanT2V480PConfig", "CosmosConfig",
"get_pipeline_config_cls_from_name"
]
73 changes: 73 additions & 0 deletions fastvideo/configs/pipelines/cosmos.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Callable
from dataclasses import dataclass, field

import torch

from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.configs.models.dits import CosmosVideoConfig
from fastvideo.configs.models.encoders import BaseEncoderOutput, T5LargeConfig
from fastvideo.configs.models.vaes import CosmosVAEConfig
from fastvideo.configs.pipelines.base import PipelineConfig


def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
"""Postprocess T5 Large text encoder outputs for Cosmos pipeline.

Return raw last_hidden_state without truncation/padding.
"""
hidden_state = outputs.last_hidden_state

if hidden_state is None:
raise ValueError("T5 Large outputs missing last_hidden_state")

nan_count = torch.isnan(hidden_state).sum()
if nan_count > 0:
hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)

# Zero out embeddings beyond actual sequence length
if outputs.attention_mask is not None:
attention_mask = outputs.attention_mask
lengths = attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
hidden_state[i, length:] = 0

return hidden_state


@dataclass
class CosmosConfig(PipelineConfig):
"""Configuration for Cosmos2 Video2World pipeline matching diffusers."""

dit_config: DiTConfig = field(default_factory=CosmosVideoConfig)

vae_config: VAEConfig = field(default_factory=CosmosVAEConfig)

text_encoder_configs: tuple[EncoderConfig, ...] = field(
default_factory=lambda: (T5LargeConfig(), ))
postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
...] = field(default_factory=lambda:
(t5_large_postprocess_text, ))

dit_precision: str = "bf16"
vae_precision: str = "fp16"
text_encoder_precisions: tuple[str, ...] = field(
default_factory=lambda: ("bf16", ))

conditioning_strategy: str = "frame_replace"
min_num_conditional_frames: int = 1
max_num_conditional_frames: int = 2
sigma_conditional: float = 0.0001
sigma_data: float = 1.0
state_ch: int = 16
state_t: int = 24
text_encoder_class: str = "T5"

embedded_cfg_scale: int = 6
flow_shift: float = 1.0

def __post_init__(self):
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True

self._vae_latent_dim = 16
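Illustrative check of the masking behavior in t5_large_postprocess_text above, using a minimal stand-in for BaseEncoderOutput (the _FakeOutput class is hypothetical; the function is the one added in this file):

import torch

from fastvideo.configs.pipelines.cosmos import t5_large_postprocess_text

class _FakeOutput:
    """Minimal stand-in exposing the two fields the postprocess reads."""

    def __init__(self, last_hidden_state, attention_mask):
        self.last_hidden_state = last_hidden_state
        self.attention_mask = attention_mask

hidden = torch.randn(2, 6, 1024)           # (batch, seq_len, d_model)
mask = torch.tensor([[1, 1, 1, 0, 0, 0],    # sample 0: true length 3
                     [1, 1, 1, 1, 1, 0]])   # sample 1: true length 5

emb = t5_large_postprocess_text(_FakeOutput(hidden, mask))

# Positions past each sample's true length are zeroed; NaNs become 0.0.
assert emb[0, 3:].abs().sum() == 0 and emb[1, 5:].abs().sum() == 0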