|  | 
|  | 1 | +# SPDX-License-Identifier: Apache-2.0 | 
|  | 2 | +from dataclasses import dataclass, field | 
|  | 3 | + | 
|  | 4 | +from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig | 
|  | 5 | + | 
|  | 6 | + | 
def is_transformer_blocks(n: str, m: object) -> bool:
    """Return True if *n* names a numbered transformer block.

    A match requires the dotted module path to mention ``transformer_blocks``
    AND end in a purely numeric component (the block index), e.g.
    ``"transformer_blocks.3"`` matches but ``"transformer_blocks.3.attn1"``
    does not.

    Used as an FSDP shard condition (see ``_fsdp_shard_conditions``), so each
    numbered block is wrapped/sharded as a unit.

    Args:
        n: Dotted module path of the candidate module.
        m: The module object itself. Unused; present only to satisfy the
           ``(name, module)`` predicate signature.

    Returns:
        True when *n* addresses a numbered transformer block directly.
    """
    # Idiomatic bound-method call instead of the original
    # ``str.isdigit(n.split(".")[-1])`` unbound form; behavior is identical
    # (``"".isdigit()`` and non-numeric tails are False).
    return "transformer_blocks" in n and n.split(".")[-1].isdigit()
|  | 9 | + | 
|  | 10 | + | 
@dataclass
class CosmosArchConfig(DiTArchConfig):
    """Architecture configuration for the Cosmos DiT transformer.

    Bundles three things: the FSDP sharding predicate list, regex tables that
    remap checkpoint parameter names onto this implementation's module names,
    and the Cosmos model hyperparameters. Derived sizes (``hidden_size``,
    ``num_channels_latents``, ``out_channels`` fallback) are computed in
    ``__post_init__``.
    """

    # Shard at the granularity of individual numbered transformer blocks
    # (predicate defined at module level in this file).
    _fsdp_shard_conditions: list = field(
        default_factory=lambda: [is_transformer_blocks])

    # Regex -> replacement-template table for remapping checkpoint parameter
    # names (presumably diffusers-style names, e.g. ``ff.net.0.proj`` --
    # verify against the loader) to this implementation's module names.
    # Most entries map a name to itself; the exceptions flatten nesting:
    #   attn*.to_out.0.* -> attn*.to_out.*
    #   ff.net.0.proj.*  -> ff.fc_in.*
    #   ff.net.2.*       -> ff.fc_out.*
    # NOTE(review): entry order may be significant if the loader applies the
    # first matching pattern -- do not reorder.
    param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^patch_embed\.(.*)$": r"patch_embed.\1",
            r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1",
            r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1",
            r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1",
            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn1.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn1.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn1.to_v.\2",
            # Drop the ``.0`` index of diffusers' to_out ModuleList.
            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
            r"transformer_blocks.\1.attn1.to_out.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
            r"transformer_blocks.\1.attn1.norm_q.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
            r"transformer_blocks.\1.attn1.norm_k.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn2.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn2.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn2.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
            r"transformer_blocks.\1.attn2.to_out.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$":
            r"transformer_blocks.\1.attn2.norm_q.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$":
            r"transformer_blocks.\1.attn2.norm_k.\2",
            # Feed-forward: net.0.proj / net.2 -> fc_in / fc_out.
            r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$":
            r"transformer_blocks.\1.ff.fc_in.\2",
            r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$":
            r"transformer_blocks.\1.ff.fc_out.\2",
            r"^norm_out\.(.*)$": r"norm_out.\1",
            r"^proj_out\.(.*)$": r"proj_out.\1",
        })

    # Same idea for LoRA weight names. Note: unlike ``param_names_mapping``,
    # the attn ``to_out`` patterns here carry no ``.0`` index and the whole
    # ``ff`` subtree is mapped in one identity rule -- LoRA checkpoints are
    # presumably already flat in those spots (TODO confirm against the LoRA
    # loader).
    lora_param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn1.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn1.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn1.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$":
            r"transformer_blocks.\1.attn1.to_out.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$":
            r"transformer_blocks.\1.attn2.to_q.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$":
            r"transformer_blocks.\1.attn2.to_k.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$":
            r"transformer_blocks.\1.attn2.to_v.\2",
            r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$":
            r"transformer_blocks.\1.attn2.to_out.\2",
            r"^transformer_blocks\.(\d+)\.ff\.(.*)$":
            r"transformer_blocks.\1.ff.\2",
        })

    # Cosmos-specific config parameters based on transformer_cosmos.py
    in_channels: int = 16
    out_channels: int = 16  # falls back to in_channels in __post_init__
    num_attention_heads: int = 16
    attention_head_dim: int = 128  # hidden_size = 16 * 128 = 2048
    num_layers: int = 28
    mlp_ratio: float = 4.0
    text_embed_dim: int = 1024
    adaln_lora_dim: int = 256
    max_size: tuple[int, int, int] = (128, 240, 240)  # (T, H, W) -- TODO confirm
    patch_size: tuple[int, int, int] = (1, 2, 2)
    rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0)
    concat_padding_mask: bool = True
    extra_pos_embed_type: str | None = None
    qk_norm: str = "rms_norm"
    eps: float = 1e-6
    # Layers whose names contain these substrings are skipped by LoRA.
    exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"])

    def __post_init__(self) -> None:
        """Derive dependent fields after dataclass initialization."""
        super().__post_init__()
        # Falsy (0/None) out_channels defaults to the input channel count.
        self.out_channels = self.out_channels or self.in_channels
        # Total model width is heads * per-head dimension.
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
        self.num_channels_latents = self.in_channels
|  | 99 | + | 
|  | 100 | + | 
@dataclass
class CosmosVideoConfig(DiTConfig):
    """Top-level DiT configuration for the Cosmos video model.

    Binds the Cosmos architecture defaults (``CosmosArchConfig``) into the
    generic ``DiTConfig`` container.
    """

    # Fresh CosmosArchConfig per instance (mutable default via factory).
    arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig)
    # NOTE(review): presumably used to namespace/identify this model family
    # during loading -- confirm against DiTConfig's use of ``prefix``.
    prefix: str = "Cosmos"
0 commit comments