From 7118caecf2f1e8766212c655a11ba2456c085704 Mon Sep 17 00:00:00 2001 From: kevin314 Date: Sat, 16 Aug 2025 10:52:25 +0000 Subject: [PATCH 1/5] Add cosmos2 i2v pipeline --- fastvideo/configs/models/dits/__init__.py | 6 +- fastvideo/configs/models/dits/cosmos.py | 104 +++ fastvideo/configs/models/encoders/__init__.py | 5 +- fastvideo/configs/models/encoders/t5.py | 23 + fastvideo/configs/models/vaes/__init__.py | 2 + fastvideo/configs/models/vaes/cosmosvae.py | 87 +++ fastvideo/configs/pipelines/__init__.py | 4 +- fastvideo/configs/pipelines/cosmos.py | 66 ++ fastvideo/configs/pipelines/registry.py | 3 + fastvideo/configs/sample/cosmos.py | 18 + fastvideo/image_processor.py | 195 +++++ fastvideo/layers/layernorm.py | 16 + fastvideo/layers/rotary_embedding.py | 53 ++ fastvideo/layers/visual_embedding.py | 76 ++ fastvideo/models/dits/cosmos.py | 726 ++++++++++++++++++ fastvideo/models/encoders/t5.py | 12 +- fastvideo/models/registry.py | 5 +- .../scheduling_flow_match_euler_discrete.py | 33 +- fastvideo/pipelines/basic/cosmos/__init__.py | 0 .../pipelines/basic/cosmos/cosmos_pipeline.py | 84 ++ fastvideo/pipelines/pipeline_registry.py | 1 + fastvideo/pipelines/stages/__init__.py | 8 +- fastvideo/pipelines/stages/decoding.py | 11 +- fastvideo/pipelines/stages/denoising.py | 287 ++++++- .../pipelines/stages/input_validation.py | 2 +- .../pipelines/stages/latent_preparation.py | 255 +++++- fastvideo/pipelines/stages/text_encoding.py | 13 + fastvideo/pipelines/stages/utils.py | 79 ++ fastvideo/tests/encoders/test_t5_encoder.py | 124 ++- fastvideo/tests/transformers/test_cosmos.py | 257 +++++++ fastvideo/worker/multiproc_executor.py | 7 + test_fastvideo_pipeline.py | 74 ++ 32 files changed, 2582 insertions(+), 54 deletions(-) create mode 100644 fastvideo/configs/models/dits/cosmos.py create mode 100644 fastvideo/configs/models/vaes/cosmosvae.py create mode 100644 fastvideo/configs/pipelines/cosmos.py create mode 100644 fastvideo/configs/sample/cosmos.py create mode 100644 fastvideo/image_processor.py create mode 100644 fastvideo/models/dits/cosmos.py create mode 100644 fastvideo/pipelines/basic/cosmos/__init__.py create mode 100644 fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py create mode 100644 fastvideo/pipelines/stages/utils.py create mode 100644 fastvideo/tests/transformers/test_cosmos.py create mode 100644 test_fastvideo_pipeline.py diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py index 72271a525..0abc716e2 100644 --- a/fastvideo/configs/models/dits/__init__.py +++ b/fastvideo/configs/models/dits/__init__.py @@ -1,5 +1,9 @@ +from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig from fastvideo.configs.models.dits.hunyuanvideo import HunyuanVideoConfig from fastvideo.configs.models.dits.stepvideo import StepVideoConfig from fastvideo.configs.models.dits.wanvideo import WanVideoConfig -__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"] +__all__ = [ + "HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig", + "CosmosVideoConfig" +] diff --git a/fastvideo/configs/models/dits/cosmos.py b/fastvideo/configs/models/dits/cosmos.py new file mode 100644 index 000000000..b76e67ed9 --- /dev/null +++ b/fastvideo/configs/models/dits/cosmos.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_transformer_blocks(n: str, m) -> bool: + return "transformer_blocks" in n and 
str.isdigit(n.split(".")[-1]) + + +@dataclass +class CosmosArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field( + default_factory=lambda: [is_transformer_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"^patch_embed\.(.*)$": r"patch_embed.\1", + r"^time_embed\.time_proj\.(.*)$": r"time_embed.time_proj.\1", + r"^time_embed\.t_embedder\.(.*)$": r"time_embed.t_embedder.\1", + r"^time_embed\.norm\.(.*)$": r"time_embed.norm.\1", + r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": + r"transformer_blocks.\1.attn1.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$": + r"transformer_blocks.\1.attn1.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$": + r"transformer_blocks.\1.attn1.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": + r"transformer_blocks.\1.attn1.to_out.\2", + r"^transformer_blocks\.(\d+)\.attn1\.norm_q\.(.*)$": + r"transformer_blocks.\1.attn1.norm_q.\2", + r"^transformer_blocks\.(\d+)\.attn1\.norm_k\.(.*)$": + r"transformer_blocks.\1.attn1.norm_k.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$": + r"transformer_blocks.\1.attn2.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$": + r"transformer_blocks.\1.attn2.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$": + r"transformer_blocks.\1.attn2.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": + r"transformer_blocks.\1.attn2.to_out.\2", + r"^transformer_blocks\.(\d+)\.attn2\.norm_q\.(.*)$": + r"transformer_blocks.\1.attn2.norm_q.\2", + r"^transformer_blocks\.(\d+)\.attn2\.norm_k\.(.*)$": + r"transformer_blocks.\1.attn2.norm_k.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.(.*)$": + r"transformer_blocks.\1.ff.fc_in.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.2\.(.*)$": + r"transformer_blocks.\1.ff.fc_out.\2", + r"^norm_out\.(.*)$": r"norm_out.\1", + r"^proj_out\.(.*)$": r"proj_out.\1", + }) + + lora_param_names_mapping: dict = field( + default_factory=lambda: { + r"^transformer_blocks\.(\d+)\.attn1\.to_q\.(.*)$": + r"transformer_blocks.\1.attn1.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_k\.(.*)$": + r"transformer_blocks.\1.attn1.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_v\.(.*)$": + r"transformer_blocks.\1.attn1.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn1\.to_out\.(.*)$": + r"transformer_blocks.\1.attn1.to_out.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_q\.(.*)$": + r"transformer_blocks.\1.attn2.to_q.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_k\.(.*)$": + r"transformer_blocks.\1.attn2.to_k.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_v\.(.*)$": + r"transformer_blocks.\1.attn2.to_v.\2", + r"^transformer_blocks\.(\d+)\.attn2\.to_out\.(.*)$": + r"transformer_blocks.\1.attn2.to_out.\2", + r"^transformer_blocks\.(\d+)\.ff\.(.*)$": + r"transformer_blocks.\1.ff.\2", + }) + + # Cosmos-specific config parameters based on transformer_cosmos.py + in_channels: int = 16 + out_channels: int = 16 + num_attention_heads: int = 16 + attention_head_dim: int = 128 + num_layers: int = 28 + mlp_ratio: float = 4.0 + text_embed_dim: int = 1024 + adaln_lora_dim: int = 256 + max_size: tuple[int, int, int] = (128, 240, 240) + patch_size: tuple[int, int, int] = (1, 2, 2) + rope_scale: tuple[float, float, float] = (1.0, 3.0, 3.0) + concat_padding_mask: bool = True + extra_pos_embed_type: str | None = None + qk_norm: str = "rms_norm" + eps: float = 1e-6 + exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"]) + + def __post_init__(self): + super().__post_init__() + 
self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.in_channels + + +@dataclass +class CosmosVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=CosmosArchConfig) + prefix: str = "Cosmos" diff --git a/fastvideo/configs/models/encoders/__init__.py b/fastvideo/configs/models/encoders/__init__.py index e56dd3a3c..7f49284f7 100644 --- a/fastvideo/configs/models/encoders/__init__.py +++ b/fastvideo/configs/models/encoders/__init__.py @@ -5,10 +5,11 @@ from fastvideo.configs.models.encoders.clip import ( CLIPTextConfig, CLIPVisionConfig, WAN2_1ControlCLIPVisionConfig) from fastvideo.configs.models.encoders.llama import LlamaConfig -from fastvideo.configs.models.encoders.t5 import T5Config +from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig __all__ = [ "EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig", "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", - "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config" + "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config", + "T5LargeConfig" ] diff --git a/fastvideo/configs/models/encoders/t5.py b/fastvideo/configs/models/encoders/t5.py index 70649551b..c1de3609c 100644 --- a/fastvideo/configs/models/encoders/t5.py +++ b/fastvideo/configs/models/encoders/t5.py @@ -70,8 +70,31 @@ def __post_init__(self): } +@dataclass +class T5LargeArchConfig(T5ArchConfig): + """T5 Large architecture config with parameters for your specific model.""" + d_model: int = 1024 + d_kv: int = 128 + d_ff: int = 65536 + num_layers: int = 24 + num_decoder_layers: int | None = 24 + num_heads: int = 128 + decoder_start_token_id: int = 0 + n_positions: int = 512 + task_specific_params: dict | None = None + + @dataclass class T5Config(TextEncoderConfig): arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig) prefix: str = "t5" + + +@dataclass +class T5LargeConfig(TextEncoderConfig): + """T5 Large configuration for your specific model.""" + arch_config: TextEncoderArchConfig = field( + default_factory=T5LargeArchConfig) + + prefix: str = "t5" diff --git a/fastvideo/configs/models/vaes/__init__.py b/fastvideo/configs/models/vaes/__init__.py index 700c8de1b..12bf5c609 100644 --- a/fastvideo/configs/models/vaes/__init__.py +++ b/fastvideo/configs/models/vaes/__init__.py @@ -1,3 +1,4 @@ +from fastvideo.configs.models.vaes.cosmosvae import CosmosVAEConfig from fastvideo.configs.models.vaes.hunyuanvae import HunyuanVAEConfig from fastvideo.configs.models.vaes.stepvideovae import StepVideoVAEConfig from fastvideo.configs.models.vaes.wanvae import WanVAEConfig @@ -6,4 +7,5 @@ "HunyuanVAEConfig", "WanVAEConfig", "StepVideoVAEConfig", + "CosmosVAEConfig", ] diff --git a/fastvideo/configs/models/vaes/cosmosvae.py b/fastvideo/configs/models/vaes/cosmosvae.py new file mode 100644 index 000000000..4680986f3 --- /dev/null +++ b/fastvideo/configs/models/vaes/cosmosvae.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +import torch + +from fastvideo.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class CosmosVAEArchConfig(VAEArchConfig): + _name_or_path: str = "" + base_dim: int = 96 + z_dim: int = 16 + dim_mult: tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + attn_scales: tuple[float, ...] = () + temperal_downsample: tuple[bool, ...] 
= (False, True, True) + dropout: float = 0.0 + decoder_base_dim: int | None = None + is_residual: bool = False + in_channels: int = 3 + out_channels: int = 3 + patch_size: int | None = None + scale_factor_temporal: int = 4 + scale_factor_spatial: int = 8 + clip_output: bool = True + latents_mean: tuple[float, ...] = ( + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ) + latents_std: tuple[float, ...] = ( + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ) + temporal_compression_ratio = 4 + spatial_compression_ratio = 8 + + def __post_init__(self): + self.scaling_factor: torch.Tensor = 1.0 / torch.tensor( + self.latents_std).view(1, self.z_dim, 1, 1, 1) + self.shift_factor: torch.Tensor = torch.tensor(self.latents_mean).view( + 1, self.z_dim, 1, 1, 1) + self.temporal_compression_ratio = self.scale_factor_temporal + self.spatial_compression_ratio = self.scale_factor_spatial + + +@dataclass +class CosmosVAEConfig(VAEConfig): + arch_config: CosmosVAEArchConfig = field( + default_factory=CosmosVAEArchConfig) + use_feature_cache: bool = True + + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + + def __post_init__(self): + self.blend_num_frames = (self.tile_sample_min_num_frames - + self.tile_sample_stride_num_frames) * 2 diff --git a/fastvideo/configs/pipelines/__init__.py b/fastvideo/configs/pipelines/__init__.py index 6ff503848..8bbb0a60c 100644 --- a/fastvideo/configs/pipelines/__init__.py +++ b/fastvideo/configs/pipelines/__init__.py @@ -1,5 +1,6 @@ from fastvideo.configs.pipelines.base import (PipelineConfig, SlidingTileAttnConfig) +from fastvideo.configs.pipelines.cosmos import CosmosConfig from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig from fastvideo.configs.pipelines.registry import ( get_pipeline_config_cls_from_name) @@ -12,5 +13,6 @@ "HunyuanConfig", "FastHunyuanConfig", "PipelineConfig", "SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig", "WanT2V720PConfig", "WanI2V720PConfig", "StepVideoT2VConfig", - "SelfForcingWanT2V480PConfig", "get_pipeline_config_cls_from_name" + "SelfForcingWanT2V480PConfig", "CosmosConfig", + "get_pipeline_config_cls_from_name" ] diff --git a/fastvideo/configs/pipelines/cosmos.py b/fastvideo/configs/pipelines/cosmos.py new file mode 100644 index 000000000..3ca78fe0f --- /dev/null +++ b/fastvideo/configs/pipelines/cosmos.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from dataclasses import dataclass, field + +import torch + +from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig +from fastvideo.configs.models.dits import CosmosVideoConfig +from fastvideo.configs.models.encoders import BaseEncoderOutput, T5LargeConfig +from fastvideo.configs.models.vaes import CosmosVAEConfig +from fastvideo.configs.pipelines.base import PipelineConfig + + +def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor: + """Postprocess T5 Large text encoder outputs for Cosmos pipeline. + + Return raw last_hidden_state without truncation/padding. 
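+    Any NaN values in the hidden state are replaced with zeros before returning.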
+ """ + hidden_state = outputs.last_hidden_state + + if hidden_state is None: + raise ValueError("T5 Large outputs missing last_hidden_state") + + nan_count = torch.isnan(hidden_state).sum() + if nan_count > 0: + hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0) + + return hidden_state + + +@dataclass +class CosmosConfig(PipelineConfig): + """Configuration for Cosmos2 Video2World pipeline matching diffusers.""" + + dit_config: DiTConfig = field(default_factory=CosmosVideoConfig) + + vae_config: VAEConfig = field(default_factory=CosmosVAEConfig) + + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (T5LargeConfig(), )) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], + ...] = field(default_factory=lambda: + (t5_large_postprocess_text, )) + + dit_precision: str = "bf16" + vae_precision: str = "fp16" + text_encoder_precisions: tuple[str, ...] = field( + default_factory=lambda: ("bf16", )) + + conditioning_strategy: str = "frame_replace" + min_num_conditional_frames: int = 1 + max_num_conditional_frames: int = 2 + sigma_conditional: float = 0.0001 + sigma_data: float = 1.0 + state_ch: int = 16 + state_t: int = 24 + text_encoder_class: str = "T5" + + embedded_cfg_scale: int = 6 + flow_shift: float = 1.0 + + def __post_init__(self): + self.vae_config.load_encoder = True + self.vae_config.load_decoder = True + + self._vae_latent_dim = 16 diff --git a/fastvideo/configs/pipelines/registry.py b/fastvideo/configs/pipelines/registry.py index 8803d0765..9b7e0bf02 100644 --- a/fastvideo/configs/pipelines/registry.py +++ b/fastvideo/configs/pipelines/registry.py @@ -5,6 +5,7 @@ from collections.abc import Callable from fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.configs.pipelines.cosmos import CosmosConfig from fastvideo.configs.pipelines.hunyuan import FastHunyuanConfig, HunyuanConfig from fastvideo.configs.pipelines.stepvideo import StepVideoT2VConfig @@ -40,6 +41,7 @@ "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config, "Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config, "Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config, + "nvidia/Cosmos-Predict2-2B-Video2World": CosmosConfig, # Add other specific weight variants } @@ -51,6 +53,7 @@ "wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(), "wancausaldmdpipeline": lambda id: "wancausaldmdpipeline" in id.lower(), "stepvideo": lambda id: "stepvideo" in id.lower(), + "cosmos": lambda id: "cosmos" in id.lower(), # Add other pipeline architecture detectors } diff --git a/fastvideo/configs/sample/cosmos.py b/fastvideo/configs/sample/cosmos.py new file mode 100644 index 000000000..32886151e --- /dev/null +++ b/fastvideo/configs/sample/cosmos.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from fastvideo.configs.sample.base import SamplingParam + + +@dataclass +class Cosmos_Predict2_2B_Video2World_SamplingParam(SamplingParam): + # Video parameters + height: int = 704 + width: int = 1280 + num_frames: int = 93 + fps: int = 16 + + # Denoising stage + guidance_scale: float = 7.0 + negative_prompt: str = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated 
special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." + num_inference_steps: int = 35 diff --git a/fastvideo/image_processor.py b/fastvideo/image_processor.py new file mode 100644 index 000000000..3483631f5 --- /dev/null +++ b/fastvideo/image_processor.py @@ -0,0 +1,195 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Minimal image processing utilities for FastVideo. +This module provides lightweight image preprocessing without external dependencies beyond PyTorch/NumPy/PIL. +""" + +import numpy as np +import PIL.Image +import torch + + +class ImageProcessor: + """ + Minimal image processor for video frame preprocessing. + + This is a lightweight alternative to diffusers.VideoProcessor that handles: + - PIL image to tensor conversion + - Resizing to specified dimensions + - Normalization to [-1, 1] range + + Args: + vae_scale_factor: The VAE scale factor used to ensure dimensions are multiples of this value. + """ + + def __init__(self, vae_scale_factor: int = 8) -> None: + self.vae_scale_factor = vae_scale_factor + + def preprocess( + self, + image: PIL.Image.Image | np.ndarray | torch.Tensor, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """ + Preprocess an image to a normalized torch tensor. + + Args: + image: Input image (PIL Image, NumPy array, or torch tensor) + height: Target height. If None, uses image's original height. + width: Target width. If None, uses image's original width. + + Returns: + torch.Tensor: Normalized tensor of shape (1, 3, height, width) or (1, 1, height, width) for grayscale, + with values in range [-1, 1]. + """ + # Handle different input types + if isinstance(image, PIL.Image.Image): + return self._preprocess_pil(image, height, width) + elif isinstance(image, np.ndarray): + return self._preprocess_numpy(image, height, width) + elif isinstance(image, torch.Tensor): + return self._preprocess_tensor(image, height, width) + else: + raise ValueError( + f"Unsupported image type: {type(image)}. 
" + "Supported types: PIL.Image.Image, np.ndarray, torch.Tensor") + + def _preprocess_pil( + self, + image: PIL.Image.Image, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """Preprocess a PIL image.""" + if height is None: + height = image.height + if width is None: + width = image.width + + height = height - (height % self.vae_scale_factor) + width = width - (width % self.vae_scale_factor) + + image = image.resize((width, height), + resample=PIL.Image.Resampling.LANCZOS) + + image_np = np.array(image, dtype=np.float32) / 255.0 + + if image_np.ndim == 2: # Grayscale + image_np = np.expand_dims(image_np, axis=-1) + + return self._normalize_to_tensor(image_np) + + def _preprocess_numpy( + self, + image: np.ndarray, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """Preprocess a numpy array.""" + # Determine target dimensions if not provided + if image.ndim == 3: + img_height, img_width = image.shape[:2] + elif image.ndim == 2: + img_height, img_width = image.shape + else: + raise ValueError(f"Expected 2D or 3D array, got {image.ndim}D") + + if height is None: + height = img_height + if width is None: + width = img_width + + height = height - (height % self.vae_scale_factor) + width = width - (width % self.vae_scale_factor) + + if image.dtype == np.uint8: + pil_image = PIL.Image.fromarray(image) + else: + # Assume normalized [0, 1] or similar + if image.max() <= 1.0: + image_uint8 = (image * 255).astype(np.uint8) + else: + image_uint8 = image.astype(np.uint8) + pil_image = PIL.Image.fromarray(image_uint8) + + pil_image = pil_image.resize((width, height), + resample=PIL.Image.Resampling.LANCZOS) + image_np = np.array(pil_image, dtype=np.float32) / 255.0 + + # Ensure 3D shape + if image_np.ndim == 2: + image_np = np.expand_dims(image_np, axis=-1) + + return self._normalize_to_tensor(image_np) + + def _preprocess_tensor( + self, + image: torch.Tensor, + height: int | None = None, + width: int | None = None, + ) -> torch.Tensor: + """Preprocess a torch tensor.""" + # Determine target dimensions + if image.ndim == 3: # (H, W, C) or (C, H, W) + if image.shape[0] in (1, 3, 4): # Likely (C, H, W) + img_height, img_width = image.shape[1], image.shape[2] + else: # Likely (H, W, C) + img_height, img_width = image.shape[0], image.shape[1] + elif image.ndim == 2: # (H, W) + img_height, img_width = image.shape + else: + raise ValueError(f"Expected 2D or 3D tensor, got {image.ndim}D") + + if height is None: + height = img_height + if width is None: + width = img_width + + height = height - (height % self.vae_scale_factor) + width = width - (width % self.vae_scale_factor) + + if image.ndim == 2: + image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) + elif image.ndim == 3: + if image.shape[0] in (1, 3, 4): # (C, H, W) + image = image.unsqueeze(0) # (1, C, H, W) + else: # (H, W, C) - need to rearrange + image = image.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) + + image = torch.nn.functional.interpolate(image, + size=(height, width), + mode="bilinear", + align_corners=False) + + if image.max() > 1.0: # Assume [0, 255] range + image = image / 255.0 + + image = 2.0 * image - 1.0 + + return image + + def _normalize_to_tensor(self, image_np: np.ndarray) -> torch.Tensor: + """ + Convert normalized numpy array [0, 1] to torch tensor [-1, 1]. 
+ + Args: + image_np: NumPy array with shape (H, W) or (H, W, C) with values in [0, 1] + + Returns: + torch.Tensor: Shape (1, C, H, W) or (1, 1, H, W) with values in [-1, 1] + """ + # Convert to tensor + if image_np.ndim == 2: # (H, W) - grayscale + tensor = torch.from_numpy(image_np).unsqueeze(0).unsqueeze( + 0) # (1, 1, H, W) + elif image_np.ndim == 3: # (H, W, C) + tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze( + 0) # (1, C, H, W) + else: + raise ValueError(f"Expected 2D or 3D array, got {image_np.ndim}D") + + # Normalize to [-1, 1] + tensor = 2.0 * tensor - 1.0 + + return tensor diff --git a/fastvideo/layers/layernorm.py b/fastvideo/layers/layernorm.py index 091ab841a..7077418b9 100644 --- a/fastvideo/layers/layernorm.py +++ b/fastvideo/layers/layernorm.py @@ -40,6 +40,22 @@ def __init__( if self.has_weight: self.weight = nn.Parameter(self.weight) + def forward_diffusers(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Forward method that matches Diffusers RMSNorm implementation exactly.""" + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + + if self.has_weight and self.weight is not None: + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + # if we do fully_shard(model.layer_norm), and we call layer_form.forward_native(input) instead of layer_norm(input), # we need to call model.layer_norm.register_fsdp_forward_method(model, "forward_native") to make sure fsdp2 hooks are triggered # for mixed precision and cpu offloading diff --git a/fastvideo/layers/rotary_embedding.py b/fastvideo/layers/rotary_embedding.py index 1270f3151..0a1c1dc28 100644 --- a/fastvideo/layers/rotary_embedding.py +++ b/fastvideo/layers/rotary_embedding.py @@ -47,6 +47,59 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: return x.flatten(-2) +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
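+        Note: when a single query or key tensor is passed, the rotated tensor (same shape as the input) is returned directly rather than as a tuple.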
+ """ + if use_real: + cos, sin = freqs_cis # [S, D] + # Match Diffusers broadcasting (sequence_dim=2 case) + cos = cos[None, None, :, :] + sin = sin[None, None, :, :] + cos, sin = cos.to(x.device), sin.to(x.device) + + if use_real_unbind_dim == -1: + # Used for flux, cogvideox, hunyuan-dit + x_real, x_imag = x.reshape(*x.shape[:-1], -1, + 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Used for Stable Audio, OmniGen, CogView4 and Cosmos + x_real, x_imag = x.reshape(*x.shape[:-1], 2, + -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError( + f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2." + ) + + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + else: + # used for lumina + x_rotated = torch.view_as_complex(x.float().reshape( + *x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + + return x_out.type_as(x) + + def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, diff --git a/fastvideo/layers/visual_embedding.py b/fastvideo/layers/visual_embedding.py index 07df51de3..9d9f0e20e 100644 --- a/fastvideo/layers/visual_embedding.py +++ b/fastvideo/layers/visual_embedding.py @@ -177,3 +177,79 @@ def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor: imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) return imgs + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +) -> torch.Tensor: + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. 
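+        By default the first half of the last dimension holds sine terms and the second half cosine terms; flip_sin_to_cos swaps this to cos, sin.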
+ """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class Timesteps(nn.Module): + + def __init__(self, + num_channels: int, + flip_sin_to_cos: bool, + downscale_freq_shift: float, + scale: int = 1): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + self.scale = scale + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + t_emb = get_timestep_embedding( + timesteps, + self.num_channels, + flip_sin_to_cos=self.flip_sin_to_cos, + downscale_freq_shift=self.downscale_freq_shift, + scale=self.scale, + ) + return t_emb \ No newline at end of file diff --git a/fastvideo/models/dits/cosmos.py b/fastvideo/models/dits/cosmos.py new file mode 100644 index 000000000..fbbdbced0 --- /dev/null +++ b/fastvideo/models/dits/cosmos.py @@ -0,0 +1,726 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from fastvideo.attention import DistributedAttention, LocalAttention +from fastvideo.configs.models.dits.cosmos import CosmosVideoConfig +from fastvideo.forward_context import get_forward_context +from fastvideo.layers.layernorm import RMSNorm +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.layers.mlp import MLP +from fastvideo.layers.rotary_embedding import apply_rotary_emb +from fastvideo.layers.visual_embedding import Timesteps +from fastvideo.models.dits.base import BaseDiT +from fastvideo.platforms import AttentionBackendEnum + + +class CosmosPatchEmbed(nn.Module): + + def __init__(self, + in_channels: int, + out_channels: int, + patch_size: tuple[int, int, int], + bias: bool = True) -> None: + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * + patch_size[2], + out_channels, + bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + hidden_states = hidden_states.reshape(batch_size, num_channels, + num_frames // p_t, p_t, + height // p_h, p_h, width // p_w, + p_w) + hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, + 7).flatten(4, 7) + hidden_states = self.proj(hidden_states) + return hidden_states + + +class CosmosTimestepEmbedding(nn.Module): + + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__() + self.linear_1 = nn.Linear(in_features, out_features, bias=False) + self.activation = nn.SiLU() + self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(timesteps) + emb = self.activation(emb) + emb = self.linear_2(emb) + return emb + + +class 
CosmosEmbedding(nn.Module): + + def __init__(self, embedding_dim: int, condition_dim: int) -> None: + super().__init__() + + self.time_proj = Timesteps(embedding_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0.0) + self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim) + self.norm = RMSNorm(embedding_dim, eps=1e-6) + + def forward(self, hidden_states: torch.Tensor, + timestep: torch.LongTensor) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep).type_as(hidden_states) + temb = self.t_embedder(timesteps_proj) + embedded_timestep = self.norm(timesteps_proj) + return temb, embedded_timestep + + +class CosmosAdaLayerNorm(nn.Module): + + def __init__(self, in_features: int, hidden_features: int) -> None: + super().__init__() + self.embedding_dim = in_features + + self.activation = nn.SiLU() + self.norm = nn.LayerNorm(in_features, + elementwise_affine=False, + eps=1e-6) + self.linear_1 = nn.Linear(in_features, hidden_features, bias=False) + self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False) + + def forward(self, + hidden_states: torch.Tensor, + embedded_timestep: torch.Tensor, + temb: torch.Tensor | None = None) -> torch.Tensor: + embedded_timestep = self.activation(embedded_timestep) + embedded_timestep = self.linear_1(embedded_timestep) + embedded_timestep = self.linear_2(embedded_timestep) + + if temb is not None: + embedded_timestep = embedded_timestep + temb[..., :2 * + self.embedding_dim] + + shift, scale = embedded_timestep.chunk(2, dim=-1) + with torch.autocast(device_type="cuda", enabled=False): + hidden_states = self.norm(hidden_states) + + if embedded_timestep.ndim == 2: + shift, scale = (x.unsqueeze(1) for x in (shift, scale)) + + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states + + +class CosmosAdaLayerNormZero(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int | None = None) -> None: + super().__init__() + + self.norm = nn.LayerNorm(in_features, + elementwise_affine=False, + eps=1e-6) + self.activation = nn.SiLU() + + if hidden_features is None: + self.linear_1 = nn.Identity() + else: + self.linear_1 = nn.Linear(in_features, hidden_features, bias=False) + + self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + embedded_timestep: torch.Tensor, + temb: torch.Tensor | None = None, + ) -> torch.Tensor: + embedded_timestep = self.activation(embedded_timestep) + embedded_timestep = self.linear_1(embedded_timestep) + embedded_timestep = self.linear_2(embedded_timestep) + + if temb is not None: + embedded_timestep = embedded_timestep + temb + + shift, scale, gate = embedded_timestep.chunk(3, dim=-1) + + with torch.autocast(device_type="cuda", enabled=False): + hidden_states = self.norm(hidden_states) + + if embedded_timestep.ndim == 2: + shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate)) + + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states, gate + + +class CosmosSelfAttention(nn.Module): + + def __init__(self, + dim: int, + num_heads: int, + qk_norm=True, + eps=1e-6, + prefix: str = "") -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + # layers - use standard PyTorch layers when using torch backend + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(dim, dim, bias=False) + self.to_v = nn.Linear(dim, dim, bias=False) 
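+        # Checkpoint keys of the form attn*.to_out.0.* are remapped onto this single
+        # Linear via CosmosArchConfig.param_names_mapping, so no ModuleList wrapper is needed.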
+ self.to_out = nn.Linear(dim, dim, bias=False) + self.dropout = nn.Dropout(0.0) + + self.norm_q = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + + def forward(self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None) -> torch.Tensor: + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + # Get QKV + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + + # Apply normalization + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Apply RoPE if provided + if image_rotary_emb is not None: + query = apply_rotary_emb(query, + image_rotary_emb, + use_real=True, + use_real_unbind_dim=-2) + key = apply_rotary_emb(key, + image_rotary_emb, + use_real=True, + use_real_unbind_dim=-2) + + # Prepare for GQA (Grouped Query Attention) + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) + key = key.repeat_interleave(query_idx // key_idx, dim=3) + value = value.repeat_interleave(query_idx // value_idx, dim=3) + + # Attention computation + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query) + + # Output projection + attn_output = self.to_out(attn_output) + attn_output = self.dropout(attn_output) + + return attn_output + + +class CosmosCrossAttention(nn.Module): + + def __init__(self, + dim: int, + cross_attention_dim: int, + num_heads: int, + qk_norm=True, + eps=1e-6, + prefix: str = "") -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.cross_attention_dim = cross_attention_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, dim, bias=False) + self.to_out = nn.Linear(dim, dim, bias=False) + self.dropout = nn.Dropout(0.0) + + self.norm_q = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(self.head_dim, + eps=eps) if qk_norm else nn.Identity() + + def forward(self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None) -> torch.Tensor: + + # Get QKV + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape for multi-head attention + query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2) + + # Apply normalization + if self.norm_q is not 
None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + # Prepare for GQA (Grouped Query Attention) + if torch.onnx.is_in_onnx_export(): + query_idx = torch.tensor(query.size(3), device=query.device) + key_idx = torch.tensor(key.size(3), device=key.device) + value_idx = torch.tensor(value.size(3), device=value.device) + else: + query_idx = query.size(3) + key_idx = key.size(3) + value_idx = value.size(3) + key = key.repeat_interleave(query_idx // key_idx, dim=3) + value = value.repeat_interleave(query_idx // value_idx, dim=3) + + # Attention computation + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query) + + # Output projection + attn_output = self.to_out(attn_output) + attn_output = self.dropout(attn_output) + + return attn_output + + +class CosmosTransformerBlock(nn.Module): + + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + mlp_ratio: float = 4.0, + adaln_lora_dim: int = 256, + qk_norm: str = "rms_norm", + out_bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, + hidden_features=adaln_lora_dim) + self.attn1 = CosmosSelfAttention( + dim=hidden_size, + num_heads=num_attention_heads, + qk_norm=(qk_norm == "rms_norm"), + eps=1e-5, + prefix=f"{prefix}.attn1") + + self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, + hidden_features=adaln_lora_dim) + self.attn2 = CosmosCrossAttention( + dim=hidden_size, + cross_attention_dim=cross_attention_dim, + num_heads=num_attention_heads, + qk_norm=(qk_norm == "rms_norm"), + eps=1e-5, + prefix=f"{prefix}.attn2") + + self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, + hidden_features=adaln_lora_dim) + self.ff = MLP(hidden_size, + int(hidden_size * mlp_ratio), + act_type="gelu", + bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + embedded_timestep: torch.Tensor, + temb: torch.Tensor | None = None, + image_rotary_emb: torch.Tensor | None = None, + extra_pos_emb: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if extra_pos_emb is not None: + hidden_states = hidden_states + extra_pos_emb + + norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, + temb) + + attn_output = self.attn1(norm_hidden_states, + image_rotary_emb=image_rotary_emb) + hidden_states = hidden_states + gate * attn_output + + norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, + temb) + attn_output = self.attn2(norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask) + + hidden_states = hidden_states + gate * attn_output + + norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, + temb) + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + gate * ff_output + + return hidden_states + + +class CosmosRotaryPosEmbed(nn.Module): + + def __init__( + self, + hidden_size: int, + max_size: tuple[int, int, int] = (128, 240, 240), + patch_size: tuple[int, int, int] = (1, 2, 2), + base_fps: int = 24, + rope_scale: tuple[float, float, float] = (2.0, 1.0, 1.0), + ) -> None: + super().__init__() + + self.max_size = [ + size // patch + for size, patch in 
zip(max_size, patch_size, strict=False) + ] + self.patch_size = patch_size + self.base_fps = base_fps + + self.dim_h = hidden_size // 6 * 2 + self.dim_w = hidden_size // 6 * 2 + self.dim_t = hidden_size - self.dim_h - self.dim_w + + self.h_ntk_factor = rope_scale[1]**(self.dim_h / (self.dim_h - 2)) + self.w_ntk_factor = rope_scale[2]**(self.dim_w / (self.dim_w - 2)) + self.t_ntk_factor = rope_scale[0]**(self.dim_t / (self.dim_t - 2)) + + + def forward(self, + hidden_states: torch.Tensor, + fps: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: + fps = 16 + batch_size, num_channels, num_frames, height, width = hidden_states.shape + pe_size = [ + num_frames // self.patch_size[0], height // self.patch_size[1], + width // self.patch_size[2] + ] + device = hidden_states.device + + h_theta = 10000.0 * self.h_ntk_factor + w_theta = 10000.0 * self.w_ntk_factor + t_theta = 10000.0 * self.t_ntk_factor + + seq = torch.arange(max(self.max_size), + device=device, + dtype=torch.float32) + dim_h_range = ( + torch.arange(0, self.dim_h, 2, device=device, + dtype=torch.float32)[:(self.dim_h // 2)] / self.dim_h) + dim_w_range = ( + torch.arange(0, self.dim_w, 2, device=device, + dtype=torch.float32)[:(self.dim_w // 2)] / self.dim_w) + dim_t_range = ( + torch.arange(0, self.dim_t, 2, device=device, + dtype=torch.float32)[:(self.dim_t // 2)] / self.dim_t) + + h_spatial_freqs = 1.0 / (h_theta**dim_h_range) + w_spatial_freqs = 1.0 / (w_theta**dim_w_range) + temporal_freqs = 1.0 / (t_theta**dim_t_range) + + emb_h = torch.outer(seq[:pe_size[1]], + h_spatial_freqs)[None, :, None, :].repeat( + pe_size[0], 1, pe_size[2], 1) + emb_w = torch.outer(seq[:pe_size[2]], + w_spatial_freqs)[None, None, :, :].repeat( + pe_size[0], pe_size[1], 1, 1) + + if fps is None: + emb_t = torch.outer(seq[:pe_size[0]], temporal_freqs) + else: + temporal_scale = seq[:pe_size[0]] / fps * self.base_fps + emb_t = torch.outer(temporal_scale, + temporal_freqs) + + emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1) + freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, + 2).float() + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return cos, sin + + +class CosmosLearnablePositionalEmbed(nn.Module): + + def __init__( + self, + hidden_size: int, + max_size: tuple[int, int, int], + patch_size: tuple[int, int, int], + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.max_size = [ + size // patch + for size, patch in zip(max_size, patch_size, strict=False) + ] + self.patch_size = patch_size + self.eps = eps + + self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], + hidden_size)) + self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], + hidden_size)) + self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], + hidden_size)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + pe_size = [ + num_frames // self.patch_size[0], height // self.patch_size[1], + width // self.patch_size[2] + ] + + emb_t = self.pos_emb_t[:pe_size[0]][None, :, None, None, :].repeat( + batch_size, 1, pe_size[1], pe_size[2], 1) + emb_h = self.pos_emb_h[:pe_size[1]][None, None, :, None, :].repeat( + batch_size, pe_size[0], 1, pe_size[2], 1) + emb_w = self.pos_emb_w[:pe_size[2]][None, None, None, :, :].repeat( + batch_size, pe_size[0], pe_size[1], 1, 1) + emb = emb_t + emb_h + emb_w + emb = emb.flatten(1, 3) + + norm = torch.linalg.vector_norm(emb, + dim=-1, + keepdim=True, + dtype=torch.float32) + norm = torch.add(self.eps, 
+ norm, + alpha=np.sqrt(norm.numel() / emb.numel())) + return (emb / norm).type_as(hidden_states) + + +class CosmosTransformer3DModel(BaseDiT): + _fsdp_shard_conditions = CosmosVideoConfig()._fsdp_shard_conditions + _compile_conditions = CosmosVideoConfig()._compile_conditions + # _supported_attention_backends = CosmosVideoConfig()._supported_attention_backends + param_names_mapping = CosmosVideoConfig().param_names_mapping + lora_param_names_mapping = CosmosVideoConfig().lora_param_names_mapping + + def __init__(self, config: CosmosVideoConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + + inner_dim = config.num_attention_heads * config.attention_head_dim + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_channels_latents = config.num_channels_latents + self.patch_size = config.patch_size + self.max_size = config.max_size + self.rope_scale = config.rope_scale + self.concat_padding_mask = config.concat_padding_mask + self.extra_pos_embed_type = config.extra_pos_embed_type + + # 1. Patch Embedding + patch_embed_in_channels = config.in_channels + 1 if config.concat_padding_mask else config.in_channels + self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, + inner_dim, + config.patch_size, + bias=False) + + # 2. Positional Embedding + self.rope = CosmosRotaryPosEmbed(hidden_size=config.attention_head_dim, + max_size=config.max_size, + patch_size=config.patch_size, + rope_scale=config.rope_scale) + + self.learnable_pos_embed = None + if config.extra_pos_embed_type == "learnable": + self.learnable_pos_embed = CosmosLearnablePositionalEmbed( + hidden_size=inner_dim, + max_size=config.max_size, + patch_size=config.patch_size, + ) + + # 3. Time Embedding + self.time_embed = CosmosEmbedding(inner_dim, inner_dim) + + # 4. Transformer Blocks + self.transformer_blocks = nn.ModuleList([ + CosmosTransformerBlock( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + cross_attention_dim=config.text_embed_dim, + mlp_ratio=config.mlp_ratio, + adaln_lora_dim=config.adaln_lora_dim, + qk_norm=config.qk_norm, + out_bias=False, + prefix=f"{config.prefix}.transformer_blocks.{i}", + ) for i in range(config.num_layers) + ]) + + # 5. 
Output norm & projection + self.norm_out = CosmosAdaLayerNorm(inner_dim, config.adaln_lora_dim) + self.proj_out = nn.Linear(inner_dim, + config.out_channels * + math.prod(config.patch_size), + bias=False) + + self.gradient_checkpointing = False + + # For TeaCache + self.previous_e0_even = None + self.previous_e0_odd = None + self.previous_residual_even = None + self.previous_residual_odd = None + self.is_even = True + self.should_calc_even = True + self.should_calc_odd = True + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.cnt = 0 + self.__post_init__() + + def forward(self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + attention_mask: torch.Tensor | None = None, + fps: int | None = None, + condition_mask: torch.Tensor | None = None, + padding_mask: torch.Tensor | None = None, + **kwargs) -> torch.Tensor: + forward_batch = get_forward_context().forward_batch + enable_teacache = forward_batch is not None and forward_batch.enable_teacache + + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + # 1. Concatenate padding mask if needed & prepare attention mask + if condition_mask is not None: + hidden_states = torch.cat([hidden_states, condition_mask], dim=1) + + if self.concat_padding_mask: + from torchvision import transforms + padding_mask = transforms.functional.resize( + padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + hidden_states = torch.cat( + [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1 + ) + + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(1).unsqueeze( + 1) # [B, 1, 1, S] + + # 2. Generate positional embeddings + image_rotary_emb = self.rope(hidden_states, fps=fps) + extra_pos_emb = self.learnable_pos_embed( + hidden_states) if self.extra_pos_embed_type == "learnable" else None + + # 3. Patchify input + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states.flatten( + 1, 3) # [B, T, H, W, C] -> [B, THW, C] codespell:ignore + + # 4. Timestep embeddings + if timestep.ndim == 1: + temb, embedded_timestep = self.time_embed(hidden_states, timestep) + elif timestep.ndim == 5: + assert timestep.shape == (batch_size, 1, num_frames, 1, 1), ( + f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}" + ) + timestep = timestep.flatten() + temb, embedded_timestep = self.time_embed(hidden_states, timestep) + # We can do this because num_frames == post_patch_num_frames, as p_t is 1 + temb, embedded_timestep = ( + x.view(batch_size, post_patch_num_frames, 1, 1, + -1).expand(-1, -1, post_patch_height, post_patch_width, + -1).flatten(1, 3) + for x in (temb, embedded_timestep) + ) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C] codespell:ignore + else: + raise ValueError(f"Unsupported timestep shape: {timestep.shape}") + + # 6. 
Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for i, block in enumerate(self.transformer_blocks): + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + embedded_timestep, + temb, + image_rotary_emb, + extra_pos_emb, + attention_mask, + ) + else: + for i, block in enumerate(self.transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + embedded_timestep=embedded_timestep, + temb=temb, + image_rotary_emb=image_rotary_emb, + extra_pos_emb=extra_pos_emb, + attention_mask=attention_mask, + ) + + # 7. Output norm & projection & unpatchify + hidden_states = self.norm_out(hidden_states, embedded_timestep, temb) + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1)) + hidden_states = hidden_states.unflatten( + 1, (post_patch_num_frames, post_patch_height, post_patch_width)) + # NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected. + # It might be a source of confusion to the reader, but this is correct + hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states \ No newline at end of file diff --git a/fastvideo/models/encoders/t5.py b/fastvideo/models/encoders/t5.py index 3ba1cb126..4a8a1711c 100644 --- a/fastvideo/models/encoders/t5.py +++ b/fastvideo/models/encoders/t5.py @@ -180,7 +180,8 @@ def __init__(self, self.qkv_proj = QKVParallelLinear( self.d_model, - self.d_model // self.total_num_heads, + #self.d_model // self.total_num_heads, + self.key_value_proj_dim, self.total_num_heads, self.total_num_kv_heads, bias=False, @@ -198,7 +199,8 @@ def __init__(self, padding_size=self.relative_attention_num_buckets, quant_config=quant_config) self.o = RowParallelLinear( - self.d_model, + #self.d_model, + self.total_num_heads * self.key_value_proj_dim, self.d_model, bias=False, quant_config=quant_config, @@ -297,10 +299,12 @@ def forward( ) -> torch.Tensor: bs, seq_len, _ = hidden_states.shape num_seqs = bs - n, c = self.n_heads, self.d_model // self.total_num_heads + #n, c = self.n_heads, self.d_model // self.total_num_heads + n, c = self.n_heads, self.key_value_proj_dim qkv, _ = self.qkv_proj(hidden_states) # Projection of 'own' hidden state (self-attention). No GQA here. 
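+        # Split by the projection's per-output sizes so the split stays correct when
+        # num_heads * d_kv != d_model (e.g. the T5LargeArchConfig added in this PR).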
- q, k, v = qkv.split(self.inner_dim, dim=-1) + #q, k, v = qkv.split(self.inner_dim, dim=-1) + q, k, v = qkv.split(self.qkv_proj.output_sizes, dim=-1) q = q.reshape(bs, seq_len, n, c) k = k.reshape(bs, seq_len, n, c) v = v.reshape(bs, seq_len, n, c) diff --git a/fastvideo/models/registry.py b/fastvideo/models/registry.py index 2b919627f..964acb836 100644 --- a/fastvideo/models/registry.py +++ b/fastvideo/models/registry.py @@ -26,7 +26,8 @@ ("dits", "hunyuanvideo", "HunyuanVideoTransformer3DModel"), "WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"), "CausalWanTransformer3DModel": ("dits", "causal_wanvideo", "CausalWanTransformer3DModel"), - "StepVideoModel": ("dits", "stepvideo", "StepVideoModel") + "StepVideoModel": ("dits", "stepvideo", "StepVideoModel"), + "CosmosTransformer3DModel": ("dits", "cosmos", "CosmosTransformer3DModel") } _IMAGE_TO_VIDEO_DIT_MODELS = { @@ -39,6 +40,7 @@ "CLIPTextModel": ("encoders", "clip", "CLIPTextModel"), "LlamaModel": ("encoders", "llama", "LlamaModel"), "UMT5EncoderModel": ("encoders", "t5", "UMT5EncoderModel"), + "T5EncoderModel": ("encoders", "t5", "T5EncoderModel"), "STEP1TextEncoder": ("encoders", "stepllm", "STEP1TextEncoder"), "BertModel": ("encoders", "clip", "CLIPTextModel"), } @@ -239,7 +241,6 @@ def register_model( def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn: all_supported_archs = self.get_supported_archs() - if any(arch in all_supported_archs for arch in architectures): raise ValueError( f"Model architectures {architectures} failed " diff --git a/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py b/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py index 618a300a8..ae8a657e6 100644 --- a/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py +++ b/fastvideo/models/schedulers/scheduling_flow_match_euler_discrete.py @@ -88,6 +88,14 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". stochastic_sampling (`bool`, defaults to False): Whether to use stochastic sampling. + final_sigmas_type (`str`, defaults to "sigma_min"): + The type of final sigmas to use. Either "sigma_min" or "zero". + sigma_max (`float`, *optional*): + The maximum sigma value for the noise schedule. + sigma_min (`float`, *optional*): + The minimum sigma value for the noise schedule. + sigma_data (`float`, *optional*): + The sigma data value for scaling. 
""" _compatibles: list[Any] = [] @@ -110,6 +118,10 @@ def __init__( use_beta_sigmas: bool | None = False, time_shift_type: str = "exponential", stochastic_sampling: bool = False, + final_sigmas_type: str = "sigma_min", + sigma_max: float | None = None, + sigma_min: float | None = None, + sigma_data: float | None = None, ): if sum([ self.config.use_beta_sigmas, self.config.use_exponential_sigmas, @@ -336,9 +348,9 @@ def set_timesteps( sigmas_array: np.ndarray if sigmas is None: if timesteps_array is None: - timesteps_array = np.linspace(self._sigma_to_t(self.sigma_max), - self._sigma_to_t(self.sigma_min), - num_inference_steps) + t_max = self._sigma_to_t(self.sigma_max) + t_min = self._sigma_to_t(self.sigma_min) + timesteps_array = np.linspace(t_max, t_min, num_inference_steps) sigmas_array = timesteps_array / self.config.num_train_timesteps else: sigmas_array = np.array(sigmas).astype(np.float32) @@ -403,9 +415,7 @@ def set_timesteps( [sigmas_tensor, torch.ones(1, device=sigmas_tensor.device)]) else: - sigmas_tensor = torch.cat( - [sigmas_tensor, - torch.zeros(1, device=sigmas_tensor.device)]) + sigmas_tensor = torch.cat([sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)]) self.timesteps = timesteps_tensor self.sigmas = sigmas_tensor @@ -505,7 +515,9 @@ def step( next_sigma = lower_sigmas[..., None] dt = current_sigma - next_sigma else: - assert self.step_index is not None, "step_index should not be None" + if self.step_index is None: + self._init_step_index(timestep) + sigma_idx = self.step_index sigma = self.sigmas[sigma_idx] sigma_next = self.sigmas[sigma_idx + 1] @@ -522,7 +534,6 @@ def step( prev_sample = sample + dt * model_output # upon completion increase step index by one - assert self._step_index is not None, "_step_index should not be None" self._step_index += 1 if per_token_timesteps is None: # Cast sample back to model compatible dtype @@ -558,7 +569,7 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, min_inv_rho = sigma_min**(1 / rho) max_inv_rho = sigma_max**(1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho - return sigmas + return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential def _convert_to_exponential(self, in_sigmas: torch.Tensor, @@ -583,7 +594,7 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, sigmas = np.exp( np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas + return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device) # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta def _convert_to_beta(self, @@ -614,7 +625,7 @@ def _convert_to_beta(self, for timestep in 1 - np.linspace(0, 1, num_inference_steps) ] ]) - return sigmas + return torch.from_numpy(sigmas).to(dtype=in_sigmas.dtype, device=in_sigmas.device) def _time_shift_exponential( self, mu: float, sigma: float, diff --git a/fastvideo/pipelines/basic/cosmos/__init__.py b/fastvideo/pipelines/basic/cosmos/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py b/fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py new file mode 100644 index 000000000..f3b7c9cd7 --- /dev/null +++ b/fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Cosmos video diffusion pipeline implementation. 
+ +This module contains an implementation of the Cosmos video diffusion pipeline +using the modular pipeline architecture. +""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler) +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase +from fastvideo.pipelines.stages import (ConditioningStage, CosmosDenoisingStage, + CosmosLatentPreparationStage, + DecodingStage, InputValidationStage, + TextEncodingStage, + TimestepPreparationStage) + +logger = init_logger(__name__) + + +class Cosmos2VideoToWorldPipeline(ComposedPipelineBase): + + _required_config_modules = [ + "text_encoder", "tokenizer", "vae", "transformer", "scheduler", + "safety_checker" + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( + shift=fastvideo_args.pipeline_config.flow_shift, + use_karras_sigmas=True) + + sigma_max = 80.0 + sigma_min = 0.002 + sigma_data = 1.0 + final_sigmas_type = "sigma_min" + + if self.modules["scheduler"] is not None: + scheduler = self.modules["scheduler"] + scheduler.config.sigma_max = sigma_max + scheduler.config.sigma_min = sigma_min + scheduler.config.sigma_data = sigma_data + scheduler.config.final_sigmas_type = final_sigmas_type + scheduler.sigma_max = sigma_max + scheduler.sigma_min = sigma_min + scheduler.sigma_data = sigma_data + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=CosmosLatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + vae=self.get_module("vae"))) + + self.add_stage(stage_name="denoising_stage", + stage=CosmosDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + +EntryClass = Cosmos2VideoToWorldPipeline diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py index 2f56d06c2..e6678e1cd 100644 --- a/fastvideo/pipelines/pipeline_registry.py +++ b/fastvideo/pipelines/pipeline_registry.py @@ -25,6 +25,7 @@ "WanCausalDMDPipeline": "wan", "StepVideoPipeline": "stepvideo", "HunyuanVideoPipeline": "hunyuan", + "Cosmos2VideoToWorldPipeline": "cosmos" } _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = { diff --git a/fastvideo/pipelines/stages/__init__.py b/fastvideo/pipelines/stages/__init__.py index 2880db3f2..7371f955d 100644 --- a/fastvideo/pipelines/stages/__init__.py +++ b/fastvideo/pipelines/stages/__init__.py @@ -10,7 +10,8 @@ from fastvideo.pipelines.stages.causal_denoising import CausalDMDDenosingStage from fastvideo.pipelines.stages.conditioning import ConditioningStage from fastvideo.pipelines.stages.decoding import 
DecodingStage -from fastvideo.pipelines.stages.denoising import (DenoisingStage, +from fastvideo.pipelines.stages.denoising import (CosmosDenoisingStage, + DenoisingStage, DmdDenoisingStage) from fastvideo.pipelines.stages.encoding import EncodingStage from fastvideo.pipelines.stages.image_encoding import (ImageEncodingStage, @@ -18,7 +19,8 @@ ImageVAEEncodingStage, VideoVAEEncodingStage) from fastvideo.pipelines.stages.input_validation import InputValidationStage -from fastvideo.pipelines.stages.latent_preparation import LatentPreparationStage +from fastvideo.pipelines.stages.latent_preparation import ( + CosmosLatentPreparationStage, LatentPreparationStage) from fastvideo.pipelines.stages.stepvideo_encoding import ( StepvideoPromptEncodingStage) from fastvideo.pipelines.stages.text_encoding import TextEncodingStage @@ -30,10 +32,12 @@ "InputValidationStage", "TimestepPreparationStage", "LatentPreparationStage", + "CosmosLatentPreparationStage", "ConditioningStage", "DenoisingStage", "DmdDenoisingStage", "CausalDMDDenosingStage", + "CosmosDenoisingStage", "EncodingStage", "DecodingStage", "ImageEncodingStage", diff --git a/fastvideo/pipelines/stages/decoding.py b/fastvideo/pipelines/stages/decoding.py index d75da2c93..830dbf34b 100644 --- a/fastvideo/pipelines/stages/decoding.py +++ b/fastvideo/pipelines/stages/decoding.py @@ -76,11 +76,12 @@ def decode(self, latents: torch.Tensor, vae_autocast_enabled = ( vae_dtype != torch.float32) and not fastvideo_args.disable_autocast - if isinstance(self.vae.scaling_factor, torch.Tensor): - latents = latents / self.vae.scaling_factor.to( - latents.device, latents.dtype) - else: - latents = latents / self.vae.scaling_factor + if hasattr(self.vae, 'scaling_factor'): + if isinstance(self.vae.scaling_factor, torch.Tensor): + latents = latents / self.vae.scaling_factor.to( + latents.device, latents.dtype) + else: + latents = latents / self.vae.scaling_factor # Apply shifting if needed if (hasattr(self.vae, "shift_factor") diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index e0510d6d9..1f231ebaa 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -403,16 +403,13 @@ def forward( **pos_cond_kwargs, ) - # Apply guidance if batch.do_classifier_free_guidance: batch.is_cfg_negative = True with set_forward_context( current_timestep=i, attn_metadata=attn_metadata, forward_batch=batch, - # fastvideo_args=fastvideo_args ): - # Run transformer noise_pred_uncond = current_model( latent_model_input, neg_prompt_embeds, @@ -421,6 +418,7 @@ def forward( **image_kwargs, **neg_cond_kwargs, ) + noise_pred_text = noise_pred noise_pred = noise_pred_uncond + current_guidance_scale * ( noise_pred_text - noise_pred_uncond) @@ -739,6 +737,289 @@ def verify_output(self, batch: ForwardBatch, return result +class CosmosDenoisingStage(DenoisingStage): + """ + Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler. 
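+
+    Each step scales the latents by c_in, runs the transformer on the
+    conditional branch (and, under classifier-free guidance, the unconditional
+    branch), recombines the prediction with c_skip/c_out, re-applies the
+    conditioning latents through the cond/uncond indicator masks, and converts
+    the result back into a velocity for the Euler scheduler step.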
+ """ + + def __init__(self, transformer, scheduler, pipeline=None) -> None: + super().__init__(transformer, scheduler, pipeline) + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + pipeline = self.pipeline() if self.pipeline else None + if not fastvideo_args.model_loaded["transformer"]: + loader = TransformerLoader() + self.transformer = loader.load( + fastvideo_args.model_paths["transformer"], fastvideo_args) + if pipeline: + pipeline.add_module("transformer", self.transformer) + fastvideo_args.model_loaded["transformer"] = True + + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + { + "generator": batch.generator, + "eta": batch.eta + }, + ) + + if hasattr(self.transformer, 'module'): + transformer_dtype = next(self.transformer.module.parameters()).dtype + else: + transformer_dtype = next(self.transformer.parameters()).dtype + target_dtype = transformer_dtype + autocast_enabled = (target_dtype != torch.float32 + ) and not fastvideo_args.disable_autocast + + latents = batch.latents + num_inference_steps = batch.num_inference_steps + guidance_scale = batch.guidance_scale + + sigma_max = 80.0 + sigma_min = 0.002 + sigma_data = 1.0 + final_sigmas_type = "sigma_min" + + if self.scheduler is not None: + self.scheduler.register_to_config( + sigma_max=sigma_max, + sigma_min=sigma_min, + sigma_data=sigma_data, + final_sigmas_type=final_sigmas_type, + ) + + self.scheduler.set_timesteps(num_inference_steps, device=latents.device) + timesteps = self.scheduler.timesteps + + if (hasattr(self.scheduler.config, 'final_sigmas_type') + and self.scheduler.config.final_sigmas_type == "sigma_min" + and len(self.scheduler.sigmas) > 1): + self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2] + + conditioning_latents = getattr(batch, 'conditioning_latents', None) + unconditioning_latents = conditioning_latents + + # Sampling loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if hasattr(self, 'interrupt') and self.interrupt: + continue + + current_sigma = self.scheduler.sigmas[i] + current_t = current_sigma / (current_sigma + 1) + c_in = 1 - current_t + c_skip = 1 - current_t + c_out = -current_t + + timestep = current_t.view(1, 1, 1, 1, + 1).expand(latents.size(0), -1, + latents.size(2), -1, + -1) # [B, 1, T, 1, 1] + + with torch.autocast(device_type="cuda", + dtype=target_dtype, + enabled=autocast_enabled): + + # Conditional forward pass + cond_latent = latents * c_in + + if hasattr( + batch, 'cond_indicator' + ) and batch.cond_indicator is not None and conditioning_latents is not None: + cond_latent = batch.cond_indicator * conditioning_latents + ( + 1 - batch.cond_indicator) * cond_latent + else: + logger.warning( + "Step %s: Missing conditioning data - cond_indicator: %s, conditioning_latents: %s", + i, hasattr(batch, 'cond_indicator'), + conditioning_latents is not None) + + cond_latent = cond_latent.to(target_dtype) + + # Apply conditional timestep processing + cond_timestep = timestep + if hasattr(batch, 'cond_indicator' + ) and batch.cond_indicator is not None: + sigma_conditioning = 0.0001 + t_conditioning = sigma_conditioning / ( + sigma_conditioning + 1) + cond_timestep = batch.cond_indicator * t_conditioning + ( + 1 - batch.cond_indicator) * timestep + cond_timestep = cond_timestep.to(target_dtype) + + with set_forward_context( + current_timestep=i, + attn_metadata=None, + forward_batch=batch, + ): + # Use conditioning masks from CosmosLatentPreparationStage + 
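+                        # cond_mask / cond_indicator come from
+                        # CosmosLatentPreparationStage: latent frames that hold
+                        # encoded conditioning content are marked 1, frames that
+                        # start as pure noise (and must be denoised) are marked 0.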
condition_mask = batch.cond_mask.to( + target_dtype) if hasattr(batch, + 'cond_mask') else None + padding_mask = torch.zeros(1, + 1, + batch.height, + batch.width, + device=cond_latent.device, + dtype=target_dtype) + + # Fallback if masks not available + if condition_mask is None: + batch_size, num_channels, num_frames, height, width = cond_latent.shape + condition_mask = torch.zeros( + batch_size, + 1, + num_frames, + height, + width, + device=cond_latent.device, + dtype=target_dtype) + + noise_pred = self.transformer( + hidden_states=cond_latent, + timestep=cond_timestep.to(target_dtype), + encoder_hidden_states=batch.prompt_embeds[0].to( + target_dtype), + fps=24, # TODO: get fps from batch or config + condition_mask=condition_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + + cond_pred = (c_skip * latents + + c_out * noise_pred.float()).to(target_dtype) + + if hasattr( + batch, 'cond_indicator' + ) and batch.cond_indicator is not None and conditioning_latents is not None: + cond_pred = batch.cond_indicator * conditioning_latents + ( + 1 - batch.cond_indicator) * cond_pred + + if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None: + uncond_latent = latents * c_in + + if hasattr( + batch, 'uncond_indicator' + ) and batch.uncond_indicator is not None and unconditioning_latents is not None: + uncond_latent = batch.uncond_indicator * unconditioning_latents + ( + 1 - batch.uncond_indicator) * uncond_latent + + with set_forward_context( + current_timestep=i, + attn_metadata=None, + forward_batch=batch, + ): + # Use uncond_mask for unconditional pass if available + uncond_condition_mask = batch.uncond_mask.to( + target_dtype + ) if hasattr( + batch, 'uncond_mask' + ) and batch.uncond_mask is not None else condition_mask + + # Apply same conditional timestep processing for unconditional pass + uncond_timestep = timestep + if hasattr(batch, 'uncond_indicator' + ) and batch.uncond_indicator is not None: + sigma_conditioning = 0.0001 # Same as Diffusers default + t_conditioning = sigma_conditioning / ( + sigma_conditioning + 1) + uncond_timestep = batch.uncond_indicator * t_conditioning + ( + 1 - batch.uncond_indicator) * timestep + uncond_timestep = uncond_timestep.to( + target_dtype) + + noise_pred_uncond = self.transformer( + hidden_states=uncond_latent.to(target_dtype), + timestep=uncond_timestep.to(target_dtype), + encoder_hidden_states=batch. 
+ negative_prompt_embeds[0].to(target_dtype), + fps=24, # TODO: get fps from batch or config + condition_mask=uncond_condition_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + + uncond_pred = ( + c_skip * latents + + c_out * noise_pred_uncond.float()).to(target_dtype) + + # Apply conditional indicator masking for unconditional prediction like diffusers + if hasattr( + batch, 'uncond_indicator' + ) and batch.uncond_indicator is not None and unconditioning_latents is not None: + uncond_pred = batch.uncond_indicator * unconditioning_latents + ( + 1 - batch.uncond_indicator) * uncond_pred + + guidance_diff = cond_pred - uncond_pred + final_pred = cond_pred + guidance_scale * guidance_diff + else: + final_pred = cond_pred + + # Convert to noise for scheduler step + if current_sigma > 1e-8: + noise_for_scheduler = (latents - final_pred) / current_sigma + else: + logger.warning( + "Step %s: current_sigma too small (%s), using final_pred directly", + i, current_sigma) + noise_for_scheduler = final_pred + + # Debug: Check for NaN values before scheduler step + if torch.isnan(noise_for_scheduler).sum() > 0: + logger.error( + "Step %s: NaN detected in noise_for_scheduler, sum: %s", + i, + noise_for_scheduler.float().sum().item()) + logger.error( + "Step %s: latents sum: %s, final_pred sum: %s, current_sigma: %s", + i, + latents.float().sum().item(), + final_pred.float().sum().item(), current_sigma) + + latents = self.scheduler.step(noise_for_scheduler, + t, + latents, + **extra_step_kwargs, + return_dict=False)[0] + + progress_bar.update() + + # Update batch with final latents + batch.latents = latents + + return batch + + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos denoising stage inputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, + [V.is_tensor, V.with_dims(5)]) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check("num_inference_steps", batch.num_inference_steps, + V.positive_int) + result.add_check("guidance_scale", batch.guidance_scale, + V.positive_float) + result.add_check("do_classifier_free_guidance", + batch.do_classifier_free_guidance, V.bool_value) + result.add_check( + "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x: + not batch.do_classifier_free_guidance or V.list_not_empty(x)) + return result + + def verify_output(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos denoising stage outputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, + [V.is_tensor, V.with_dims(5)]) + return result + + class DmdDenoisingStage(DenoisingStage): """ Denoising stage for DMD. diff --git a/fastvideo/pipelines/stages/input_validation.py b/fastvideo/pipelines/stages/input_validation.py index 5b06e968e..62f6c06c7 100644 --- a/fastvideo/pipelines/stages/input_validation.py +++ b/fastvideo/pipelines/stages/input_validation.py @@ -41,7 +41,7 @@ def _generate_seeds(self, batch: ForwardBatch, batch.seeds = seeds # Peiyuan: using GPU seed will cause A100 and H100 to generate different results... 
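+        # Generators are therefore created on the CPU so that a given seed
+        # yields the same noise on any GPU architecture.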
batch.generator = [ - torch.Generator("cpu").manual_seed(seed) for seed in seeds + torch.Generator(device="cpu").manual_seed(seed) for seed in seeds ] def forward( diff --git a/fastvideo/pipelines/stages/latent_preparation.py b/fastvideo/pipelines/stages/latent_preparation.py index ea23a7daa..0cf15d7d8 100644 --- a/fastvideo/pipelines/stages/latent_preparation.py +++ b/fastvideo/pipelines/stages/latent_preparation.py @@ -3,10 +3,14 @@ Latent preparation stage for diffusion pipelines. """ +from typing import Any + +import torch from diffusers.utils.torch_utils import randn_tensor from fastvideo.distributed import get_local_torch_device from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.image_processor import ImageProcessor from fastvideo.logger import init_logger from fastvideo.pipelines.pipeline_batch_info import ForwardBatch from fastvideo.pipelines.stages.base import PipelineStage @@ -108,6 +112,238 @@ def forward( return batch + +class CosmosLatentPreparationStage(PipelineStage): + """ + Cosmos-specific latent preparation stage that properly handles the tensor shapes + and conditioning masks required by the Cosmos transformer. + + This stage replicates the logic from diffusers' Cosmos2VideoToWorldPipeline.prepare_latents() + """ + + def __init__(self, scheduler, transformer, vae=None) -> None: + super().__init__() + self.scheduler = scheduler + self.transformer = transformer + self.vae = vae + + def forward( + self, + batch: ForwardBatch, + fastvideo_args: FastVideoArgs, + ) -> ForwardBatch: + # Determine batch size + if isinstance(batch.prompt, list): + batch_size = len(batch.prompt) + elif batch.prompt is not None: + batch_size = 1 + else: + batch_size = batch.prompt_embeds[0].shape[0] + + # Adjust batch size for number of videos per prompt + batch_size *= batch.num_videos_per_prompt + + # Get required parameters + # Force float32 for latent preparation + dtype = torch.float32 + device = get_local_torch_device() + generator = batch.generator + latents = batch.latents + num_frames = batch.num_frames + height = batch.height + width = batch.width + + if height is None or width is None: + raise ValueError("Height and width must be provided") + + vae_scale_factor_spatial = 8 + vae_scale_factor_temporal = 4 + + latent_height = height // 8 + latent_width = width // vae_scale_factor_spatial + + num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 + + # Cosmos transformer expects in_channels - 1 for the latent channels + num_channels_latents = self.transformer.config.in_channels - 1 + + shape = (batch_size, num_channels_latents, num_latent_frames, + latent_height, latent_width) + + init_latents = None + conditioning_latents = None + + video = None + + if hasattr(batch, 'video') and batch.video is not None: + video = batch.video + elif hasattr(batch, 'pil_image') and batch.pil_image is not None: + vae_scale_factor_spatial = 8 + image_processor = ImageProcessor( + vae_scale_factor=vae_scale_factor_spatial) + + processed_image = image_processor.preprocess( + batch.pil_image, height, width) + + # Add time dimension + video = processed_image.unsqueeze(2) + + video = video.to(device=device, dtype=torch.bfloat16) + elif hasattr( + batch, + 'preprocessed_image') and batch.preprocessed_image is not None: + # Convert preprocessed image to video format + if isinstance(batch.preprocessed_image, torch.Tensor): + if batch.preprocessed_image.dim( + ) == 4: # [B, C, H, W] -> [B, C, T, H, W] + video = batch.preprocessed_image.unsqueeze(2) + elif batch.preprocessed_image.dim( + 
) == 5: # Already [B, C, T, H, W] + video = batch.preprocessed_image + else: + logger.info( + "CosmosLatentPreparationStage - No video input sources found") + + if video is not None: + num_cond_frames = video.size(2) + + if num_cond_frames >= num_frames: + # Take the last `num_frames` frames for conditioning + num_cond_latent_frames = (num_frames - + 1) // vae_scale_factor_temporal + 1 + video = video[:, :, -num_frames:] + else: + num_cond_latent_frames = (num_cond_frames - + 1) // vae_scale_factor_temporal + 1 + num_padding_frames = num_frames - num_cond_frames + last_frame = video[:, :, -1:] + padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1) + video = torch.cat([video, padding], dim=2) + + if self.vae is not None: + # Move VAE to correct device before encoding + self.vae = self.vae.to(device) + self.vae = self.vae.to(dtype=video.dtype) + + def retrieve_latents( + encoder_output: Any, + generator: Any | None = None) -> torch.Tensor: + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + elif hasattr(encoder_output, "sample"): + return encoder_output.sample(generator) + elif isinstance(encoder_output, torch.Tensor): + return encoder_output + else: + attrs = [ + attr for attr in dir(encoder_output) + if not attr.startswith('_') + ] + raise AttributeError( + f"Could not access latents of provided encoder_output. Available attributes: {attrs}" + ) + + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), + generator=torch.Generator( + device="cpu").manual_seed(100)) + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents( + self.vae.encode(vid.unsqueeze(0)), + torch.Generator(device="cpu").manual_seed(100)) + for vid in video + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + # Apply latent normalization + if hasattr(self.vae.config, 'latents_mean') and hasattr( + self.vae.config, 'latents_std'): + latents_mean = torch.tensor( + self.vae.config.latents_mean).view( + 1, self.vae.config.z_dim, 1, 1, + 1).to(device, dtype) + latents_std = torch.tensor( + self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, + 1).to(device, dtype) + init_latents = (init_latents - latents_mean + ) / latents_std * self.scheduler.sigma_data + + conditioning_latents = init_latents + + # Offload VAE to CPU after encoding to save memory + self.vae.to("cpu") + else: + num_cond_latent_frames = 0 + + # Generate or use provided latents + if latents is None: + # Use float32 for randn_tensor + latents = randn_tensor( + shape, + generator=torch.Generator(device="cpu").manual_seed(200), + device=device, + dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latents = latents * self.scheduler.sigma_max + + padding_shape = (batch_size, 1, num_latent_frames, latent_height, + latent_width) + ones_padding = latents.new_ones(padding_shape) + zeros_padding = latents.new_zeros(padding_shape) + + cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator[:, :, :num_cond_latent_frames] = 1.0 + cond_mask = cond_indicator * ones_padding + ( + 1 - cond_indicator) * zeros_padding + + uncond_indicator = None + uncond_mask = None + if batch.do_classifier_free_guidance: + uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + uncond_indicator[:, :, :num_cond_latent_frames] = 1.0 + uncond_mask = uncond_indicator * ones_padding + ( + 1 - 
uncond_indicator) * zeros_padding + + batch.latents = latents + batch.raw_latent_shape = latents.shape + + batch.conditioning_latents = conditioning_latents + batch.cond_indicator = cond_indicator + batch.uncond_indicator = uncond_indicator + batch.cond_mask = cond_mask + batch.uncond_mask = uncond_mask + + return batch + + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos latent preparation stage inputs.""" + result = VerificationResult() + result.add_check( + "prompt_or_embeds", None, lambda _: V.string_or_list_strings( + batch.prompt) or V.list_not_empty(batch.prompt_embeds)) + result.add_check("prompt_embeds", batch.prompt_embeds, + V.list_of_tensors) + result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, + V.positive_int) + result.add_check("generator", batch.generator, + V.generator_or_list_generators) + result.add_check("num_frames", batch.num_frames, V.positive_int) + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("latents", batch.latents, V.none_or_tensor) + return result + def adjust_video_length(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int: """ @@ -130,25 +366,6 @@ def adjust_video_length(self, batch: ForwardBatch, latent_num_frames = video_length // 17 * 3 return int(latent_num_frames) - def verify_input(self, batch: ForwardBatch, - fastvideo_args: FastVideoArgs) -> VerificationResult: - """Verify latent preparation stage inputs.""" - result = VerificationResult() - result.add_check( - "prompt_or_embeds", None, lambda _: V.string_or_list_strings( - batch.prompt) or V.list_not_empty(batch.prompt_embeds)) - result.add_check("prompt_embeds", batch.prompt_embeds, - V.list_of_tensors) - result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, - V.positive_int) - result.add_check("generator", batch.generator, - V.generator_or_list_generators) - result.add_check("num_frames", batch.num_frames, V.positive_int) - result.add_check("height", batch.height, V.positive_int) - result.add_check("width", batch.width, V.positive_int) - result.add_check("latents", batch.latents, V.none_or_tensor) - return result - def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: """Verify latent preparation stage outputs.""" diff --git a/fastvideo/pipelines/stages/text_encoding.py b/fastvideo/pipelines/stages/text_encoding.py index 4353f1b26..b23671775 100644 --- a/fastvideo/pipelines/stages/text_encoding.py +++ b/fastvideo/pipelines/stages/text_encoding.py @@ -69,6 +69,13 @@ def forward( encoder_index=all_indices, return_attention_mask=True, ) + + # Zero out embeddings beyond actual sequence length + for prompt_embeds, attention_mask in zip(prompt_embeds_list, prompt_masks_list): + lengths = attention_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + prompt_embeds[i, length:] = 0 + for pe in prompt_embeds_list: batch.prompt_embeds.append(pe) if batch.prompt_attention_mask is not None: @@ -84,6 +91,12 @@ def forward( encoder_index=all_indices, return_attention_mask=True, ) + + for neg_embeds, neg_mask in zip(neg_embeds_list, neg_masks_list): + lengths = neg_mask.sum(dim=1).cpu() + for i, length in enumerate(lengths): + neg_embeds[i, length:] = 0 + assert batch.negative_prompt_embeds is not None for ne in neg_embeds_list: batch.negative_prompt_embeds.append(ne) diff --git a/fastvideo/pipelines/stages/utils.py b/fastvideo/pipelines/stages/utils.py 
new file mode 100644 index 000000000..c7c272ab8 --- /dev/null +++ b/fastvideo/pipelines/stages/utils.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Utility functions for pipeline stages. +""" + +import inspect +from typing import Any + +import torch + + +def retrieve_timesteps( + scheduler: Any, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs: Any, +) -> tuple[Any, int]: + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + if timesteps is None: + raise ValueError("scheduler.timesteps is None after set_timesteps") + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + if timesteps is None: + raise ValueError("scheduler.timesteps is None after set_timesteps") + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + if timesteps is None: + raise ValueError("scheduler.timesteps is None after set_timesteps") + num_inference_steps = len(timesteps) + return timesteps, num_inference_steps diff --git a/fastvideo/tests/encoders/test_t5_encoder.py b/fastvideo/tests/encoders/test_t5_encoder.py index fa75d3691..0e93b53bb 100644 --- a/fastvideo/tests/encoders/test_t5_encoder.py +++ b/fastvideo/tests/encoders/test_t5_encoder.py @@ -6,7 +6,7 @@ import torch from torch.distributed.tensor import DTensor from torch.testing import assert_close -from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel +from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel, T5EncoderModel from fastvideo.configs.pipelines import PipelineConfig from fastvideo.forward_context import set_forward_context @@ -14,14 +14,15 @@ from fastvideo.models.loader.component_loader import TextEncoderLoader from fastvideo.utils import maybe_download_model, PRECISION_TO_TYPE from fastvideo.fastvideo_args import FastVideoArgs -from fastvideo.configs.models.encoders import T5Config +from fastvideo.configs.models.encoders import T5Config, T5LargeConfig logger = init_logger(__name__) os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29503" -BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +#BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World" MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, local_dir=os.path.join( 'data', BASE_MODEL_PATH)) @@ -118,6 +119,123 @@ def test_t5_encoder(): last_hidden_state1 = outputs1[tokens.attention_mask == 1] last_hidden_state2 = outputs2[tokens.attention_mask == 1] + assert last_hidden_state1.shape == last_hidden_state2.shape, \ + f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}" + + max_diff_hidden = torch.max( + torch.abs(last_hidden_state1 - last_hidden_state2)) + mean_diff_hidden = torch.mean( + torch.abs(last_hidden_state1 - last_hidden_state2)) + + logger.info("Maximum difference in last hidden states: %s", + max_diff_hidden.item()) + logger.info("Mean difference in last hidden states: %s", + mean_diff_hidden.item()) + logger.info("Max memory allocated: %s GB", torch.cuda.max_memory_allocated() / 1024**3) + # Check if outputs are similar (allowing for small numerical differences) + assert mean_diff_hidden < 1e-4, \ + f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}" + assert max_diff_hidden < 1e-4, \ + f"Hidden states differ significantly: max diff = {max_diff_hidden.item()}" + + +@pytest.mark.usefixtures("distributed_setup") +def test_t5_large_encoder(): + # Initialize the two model implementations + hf_config = AutoConfig.from_pretrained(TEXT_ENCODER_PATH) + print(hf_config) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision_str = "fp32" + precision = PRECISION_TO_TYPE[precision_str] + model1 = T5EncoderModel.from_pretrained(TEXT_ENCODER_PATH).to( + precision).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) + + args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, + 
pipeline_config=PipelineConfig(text_encoder_configs=(T5LargeConfig(),), + text_encoder_precisions=(precision_str,)), + pin_cpu_memory=False) + loader = TextEncoderLoader() + model2 = loader.load(TEXT_ENCODER_PATH, args) + model2 = model2.to(precision) + model2.eval() + + # Sanity check weights between the two models + logger.info("Comparing model weights for sanity check...") + params1 = dict(model1.named_parameters()) + params2 = dict(model2.named_parameters()) + + # Check number of parameters + logger.info("Model1 has %s parameters", len(params1)) + logger.info("Model2 has %s parameters", len(params2)) + + # # Print parameter names for comparison + # logger.info("Model1 parameters:") + # for name in sorted(params1.keys()): + # logger.info(" %s: %s", name, params1[name].shape) + + # logger.info("Model2 parameters:") + # for name in sorted(params2.keys()): + # logger.info(" %s: %s", name, params2[name].shape) + + weight_diffs = [] + # check if embed_tokens are the same + # weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \ + # "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\ + # "encoder.block.{}.layer.1.DenseReluDense.wo.weight", \ + # "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight"] + + # for idx in range(hf_config.num_hidden_layers): + # for w in weights: + # name1 = w.format(idx) + # name2 = w.format(idx) + # p1 = params1[name1] + # p2 = params2[name2] + # p2 = (p2.to_local() if isinstance(p2, DTensor) else p2).to(p1) + # assert_close(p1, p2, atol=1e-4, rtol=1e-4) + + + # Test with some sample prompts + prompts = [ + "Once upon a time", "The quick brown fox jumps over", + "In a galaxy far, far away" + ] + + logger.info("Testing T5 Large encoder with sample prompts") + + with torch.no_grad(): + for prompt in prompts: + logger.info("Testing prompt: %s", prompt) + + # Tokenize the prompt + tokens = tokenizer(prompt, + padding="max_length", + max_length=512, + truncation=True, + return_tensors="pt").to(device) + + # Get outputs from HuggingFace implementation + # filter out padding input_ids + # tokens.input_ids = tokens.input_ids[tokens.attention_mask==1] + # tokens.attention_mask = tokens.attention_mask[tokens.attention_mask==1] + outputs1 = model1(input_ids=tokens.input_ids, + attention_mask=tokens.attention_mask, + output_hidden_states=True).last_hidden_state + print("--------------------------------") + logger.info("Testing model2 with T5LargeConfig") + + # Get outputs from our implementation + with set_forward_context(current_timestep=0, attn_metadata=None): + outputs2 = model2( + input_ids=tokens.input_ids, + attention_mask=tokens.attention_mask, + ).last_hidden_state + + # Compare last hidden states + last_hidden_state1 = outputs1[tokens.attention_mask == 1] + last_hidden_state2 = outputs2[tokens.attention_mask == 1] + assert last_hidden_state1.shape == last_hidden_state2.shape, \ f"Hidden state shapes don't match: {last_hidden_state1.shape} vs {last_hidden_state2.shape}" diff --git a/fastvideo/tests/transformers/test_cosmos.py b/fastvideo/tests/transformers/test_cosmos.py new file mode 100644 index 000000000..2b6ca5dbb --- /dev/null +++ b/fastvideo/tests/transformers/test_cosmos.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import numpy as np +import pytest +import torch +from diffusers.models.transformers.transformer_cosmos import 
CosmosTransformer3DModel + +from fastvideo.configs.pipelines import PipelineConfig +from fastvideo.forward_context import set_forward_context +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import TransformerLoader +from fastvideo.utils import maybe_download_model +from fastvideo.configs.models.dits import CosmosVideoConfig +from fastvideo.pipelines.pipeline_batch_info import ForwardBatch + + +logger = init_logger(__name__) + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29504" + +# BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Text2Image" +BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World" +MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, + local_dir=os.path.join( + 'data', BASE_MODEL_PATH)) +TRANSFORMER_PATH = os.path.join(MODEL_PATH, "transformer") + + +@pytest.mark.usefixtures("distributed_setup") +def test_cosmos2_transformer(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision = torch.bfloat16 + precision_str = "bf16" + args = FastVideoArgs(model_path=TRANSFORMER_PATH, + dit_cpu_offload=False, + use_fsdp_inference=False, + pipeline_config=PipelineConfig(dit_config=CosmosVideoConfig(), dit_precision=precision_str)) + + loader = TransformerLoader() + model2 = loader.load(TRANSFORMER_PATH, args).to(device, dtype=precision) + + model1 = CosmosTransformer3DModel.from_pretrained( + TRANSFORMER_PATH, torch_dtype=precision).to(device, dtype=precision).requires_grad_(False) + + total_params = sum(p.numel() for p in model1.parameters()) + # Calculate weight sum for model1 (converting to float64 to avoid overflow) + weight_sum_model1 = sum( + p.to(torch.float64).sum().item() for p in model1.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model1 = weight_sum_model1 / total_params + logger.info("Model 1 weight sum: %s", weight_sum_model1) + logger.info("Model 1 weight mean: %s", weight_mean_model1) + + # Calculate weight sum for model2 (converting to float64 to avoid overflow) + total_params_model2 = sum(p.numel() for p in model2.parameters()) + weight_sum_model2 = sum( + p.to(torch.float64).sum().item() for p in model2.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model2 = weight_sum_model2 / total_params_model2 + logger.info("Model 2 weight sum: %s", weight_sum_model2) + logger.info("Model 2 weight mean: %s", weight_mean_model2) + + weight_sum_diff = abs(weight_sum_model1 - weight_sum_model2) + logger.info("Weight sum difference: %s", weight_sum_diff) + weight_mean_diff = abs(weight_mean_model1 - weight_mean_model2) + logger.info("Weight mean difference: %s", weight_mean_diff) + + # Set both models to eval mode + model1 = model1.eval() + model2 = model2.eval() + + # Create identical inputs for both models + batch_size = 1 + seq_len = 30 + + # Video latents [B, C, T, H, W] - Cosmos2 specific dimensions + hidden_states = torch.randn(batch_size, + 17, + 1, # Single frame for image generation + 32, # Height patches + 32, # Width patches + device=device, + dtype=precision) + + # Text embeddings [B, L, D] - Cosmos2 uses T5 embeddings with 1024 dim + encoder_hidden_states = torch.randn(batch_size, + seq_len, + 1024, # T5 embedding dimension + device=device, + dtype=precision) + + # Timestep + timestep = torch.tensor([500], device=device, dtype=precision) + + # padding mask + padding_mask = hidden_states.new_zeros(1, 1, 32, 32, device=device, dtype=precision) + # 
print(padding_mask.shape) + + forward_batch = ForwardBatch( + data_type="dummy", + ) + + with torch.autocast('cuda', dtype=precision): + output1 = model1( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask, + return_dict=False, + )[0] + with set_forward_context( + current_timestep=0, + attn_metadata=None, + forward_batch=forward_batch, + ): + output2 = model2(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask) + + # Check if outputs have the same shape + assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + assert output1.dtype == output2.dtype, f"Output dtype don't match: {output1.dtype} vs {output2.dtype}" + + # Check if outputs are similar (allowing for small numerical differences) + max_diff = torch.max(torch.abs(output1 - output2)) + mean_diff = torch.mean(torch.abs(output1 - output2)) + logger.info("Max Diff: %s", max_diff.item()) + logger.info("Mean Diff: %s", mean_diff.item()) + assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" + # mean diff + assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" + + +@pytest.mark.usefixtures("distributed_setup") +def test_cosmos2_transformer_video2world(): + """Test Cosmos2 Video2World variant""" + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision = torch.bfloat16 + precision_str = "bf16" + + # Use Video2World model path + base_model_path = "nvidia/Cosmos-Predict2-2B-Video2World" + model_path = maybe_download_model(base_model_path, + local_dir=os.path.join( + 'data', base_model_path)) + transformer_path = os.path.join(model_path, "transformer") + + # Use torch attention backend for exact diffusers matching + cosmos_config = CosmosVideoConfig() + cosmos_config.arch_config.attention_backend = "torch" + + args = FastVideoArgs(model_path=transformer_path, + dit_cpu_offload=False, + use_fsdp_inference=False, + enable_torch_compile=False, + disable_autocast=True, + pipeline_config=PipelineConfig(dit_config=cosmos_config, dit_precision=precision_str)) + + loader = TransformerLoader() + model2 = loader.load(transformer_path, args).to(device, dtype=precision) + + model1 = CosmosTransformer3DModel.from_pretrained( + transformer_path, torch_dtype=precision).to(device, dtype=precision).requires_grad_(False) + + total_params = sum(p.numel() for p in model1.parameters()) + # Calculate weight sum for model1 (converting to float64 to avoid overflow) + weight_sum_model1 = sum( + p.to(torch.float64).sum().item() for p in model1.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model1 = weight_sum_model1 / total_params + logger.info("Model 1 weight sum: %s", weight_sum_model1) + logger.info("Model 1 weight mean: %s", weight_mean_model1) + + # Calculate weight sum for model2 (converting to float64 to avoid overflow) + total_params_model2 = sum(p.numel() for p in model2.parameters()) + weight_sum_model2 = sum( + p.to(torch.float64).sum().item() for p in model2.parameters()) + # Also calculate mean for more stable comparison + weight_mean_model2 = weight_sum_model2 / total_params_model2 + logger.info("Model 2 weight sum: %s", weight_sum_model2) + logger.info("Model 2 weight mean: %s", weight_mean_model2) + + weight_sum_diff = abs(weight_sum_model1 - weight_sum_model2) + logger.info("Weight sum difference: %s", weight_sum_diff) + weight_mean_diff = 
abs(weight_mean_model1 - weight_mean_model2) + logger.info("Weight mean difference: %s", weight_mean_diff) + + # Set both models to eval mode + model1 = model1.eval() + model2 = model2.eval() + + # Create identical inputs for both models + batch_size = 1 + seq_len = 30 + + # Video latents [B, C, T, H, W] - Video2World has additional condition channel + hidden_states = torch.randn(batch_size, + 17, # 16 + 1 for condition channel + 8, # Multiple frames for video + 32, # Height patches + 32, # Width patches + device=device, + dtype=precision) + + # Text embeddings [B, L, D] + encoder_hidden_states = torch.randn(batch_size, + seq_len, + 1024, # T5 embedding dimension + device=device, + dtype=precision) + + # Timestep + timestep = torch.tensor([500], device=device, dtype=precision) + + # padding mask + padding_mask = hidden_states.new_zeros(1, 1, 32, 32, device=device, dtype=precision) + + forward_batch = ForwardBatch( + data_type="dummy", + ) + + with torch.autocast('cuda', dtype=precision, enabled=False): + output1 = model1( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask, + return_dict=False, + )[0] + with set_forward_context( + current_timestep=0, + attn_metadata=None, + forward_batch=forward_batch, + ): + output2 = model2(hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + padding_mask=padding_mask) + + # Check if outputs have the same shape + assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" + assert output1.dtype == output2.dtype, f"Output dtype don't match: {output1.dtype} vs {output2.dtype}" + + # Check if outputs are similar (allowing for small numerical differences) + max_diff = torch.max(torch.abs(output1 - output2)) + mean_diff = torch.mean(torch.abs(output1 - output2)) + logger.info("Max Diff: %s", max_diff.item()) + logger.info("Mean Diff: %s", mean_diff.item()) + + # With torch attention backend, outputs should now match closely + assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" + # mean diff + assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" \ No newline at end of file diff --git a/fastvideo/worker/multiproc_executor.py b/fastvideo/worker/multiproc_executor.py index cc8c419af..310f3e7cf 100644 --- a/fastvideo/worker/multiproc_executor.py +++ b/fastvideo/worker/multiproc_executor.py @@ -32,6 +32,13 @@ def _init_executor(self) -> None: self.world_size = self.fastvideo_args.num_gpus self.shutting_down = False + # Initialize CUDA before setting up multiprocessing to ensure + # maybe_force_spawn() correctly detects CUDA and forces 'spawn' mode. + # This prevents "Cannot re-initialize CUDA in forked subprocess" errors. + import torch + if torch.cuda.is_available(): + torch.cuda.init() + set_multiproc_executor_envs() # Check if master_port is provided in fastvideo_args diff --git a/test_fastvideo_pipeline.py b/test_fastvideo_pipeline.py new file mode 100644 index 000000000..2913767fe --- /dev/null +++ b/test_fastvideo_pipeline.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +""" +Simple script to generate a video using the FastVideo generator. 
+""" + +import os +import sys + +from fastvideo.entrypoints.video_generator import VideoGenerator + + +def generate_video() -> bool: + """Generate a video using the FastVideo generator.""" + + # Configuration + #input_image_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/tennis.jpg" + #prompt = "A tennis ball bouncing on a racquet, the ball moves in a smooth arc as it hits the strings and rebounds with natural physics. The racquet strings vibrate slightly from the impact, and the ball continues its trajectory with realistic motion." + input_image_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/yellow-scrubber.png" + prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." + negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." 
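+    # NOTE: the image path, prompts and output path above are example values
+    # tied to the author's environment; point them at local files before
+    # running this script.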
+ output_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/cosmos2_fastvideo_output.mp4" + + # Check if input image exists + if not os.path.exists(input_image_path): + print(f"Error: Input image not found: {input_image_path}") + return False + + try: + # Create video generator + print("Creating FastVideo generator...") + generator = VideoGenerator.from_pretrained( + model_path="nvidia/Cosmos-Predict2-2B-Video2World", + num_gpus=1, + ) + + print("Generator created successfully") + + # Run inference + print("Generating video...") + result = generator.generate_video(height=704, + width=1280, + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=21, + image_path=input_image_path, + num_inference_steps=35, + guidance_scale=7.0, + seed=1, + save_video=True, + output_path=output_path, + fps=16) + + if result: + print("Video generation completed successfully!") + return True + else: + print("Video generation failed - no result returned") + return False + + except Exception as e: + print(f"Error during video generation: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = generate_video() + if success: + print("✅ Video generation completed successfully") + sys.exit(0) + else: + print("❌ Video generation failed") + sys.exit(1) From 1b7bfe591b44c2f3108d24aeb47bd0a8b36bdab3 Mon Sep 17 00:00:00 2001 From: kevin314 Date: Sun, 26 Oct 2025 05:11:17 +0000 Subject: [PATCH 2/5] pre-commit --- fastvideo/configs/models/encoders/__init__.py | 3 +-- fastvideo/pipelines/stages/text_encoding.py | 8 ++++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fastvideo/configs/models/encoders/__init__.py b/fastvideo/configs/models/encoders/__init__.py index 7f49284f7..29b5d6113 100644 --- a/fastvideo/configs/models/encoders/__init__.py +++ b/fastvideo/configs/models/encoders/__init__.py @@ -10,6 +10,5 @@ __all__ = [ "EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig", "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", - "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config", - "T5LargeConfig" + "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config", "T5LargeConfig" ] diff --git a/fastvideo/pipelines/stages/text_encoding.py b/fastvideo/pipelines/stages/text_encoding.py index b23671775..0edde936a 100644 --- a/fastvideo/pipelines/stages/text_encoding.py +++ b/fastvideo/pipelines/stages/text_encoding.py @@ -71,7 +71,9 @@ def forward( ) # Zero out embeddings beyond actual sequence length - for prompt_embeds, attention_mask in zip(prompt_embeds_list, prompt_masks_list): + for prompt_embeds, attention_mask in zip(prompt_embeds_list, + prompt_masks_list, + strict=False): lengths = attention_mask.sum(dim=1).cpu() for i, length in enumerate(lengths): prompt_embeds[i, length:] = 0 @@ -92,7 +94,9 @@ def forward( return_attention_mask=True, ) - for neg_embeds, neg_mask in zip(neg_embeds_list, neg_masks_list): + for neg_embeds, neg_mask in zip(neg_embeds_list, + neg_masks_list, + strict=False): lengths = neg_mask.sum(dim=1).cpu() for i, length in enumerate(lengths): neg_embeds[i, length:] = 0 From cb657093b70b5eabe154b299d546bfdd3601cfa5 Mon Sep 17 00:00:00 2001 From: kevin314 Date: Mon, 27 Oct 2025 09:43:30 +0000 Subject: [PATCH 3/5] Fix ci issues --- .buildkite/scripts/pr_test.sh | 10 +- fastvideo/layers/layernorm.py | 16 --- .../pipelines/stages/latent_preparation.py | 88 ++++++++++--- fastvideo/tests/encoders/test_t5_encoder.py | 99 ++++++++------ fastvideo/tests/modal/pr_test.py | 12 +- 
fastvideo/tests/transformers/test_cosmos.py | 123 ------------------ 6 files changed, 136 insertions(+), 212 deletions(-) diff --git a/.buildkite/scripts/pr_test.sh b/.buildkite/scripts/pr_test.sh index 7819ceb2a..93704a9b0 100755 --- a/.buildkite/scripts/pr_test.sh +++ b/.buildkite/scripts/pr_test.sh @@ -31,9 +31,9 @@ log "Setting up Modal authentication from Buildkite secrets..." MODAL_TOKEN_ID=$(buildkite-agent secret get modal_token_id) MODAL_TOKEN_SECRET=$(buildkite-agent secret get modal_token_secret) +# Retrieve other secrets WANDB_API_KEY=$(buildkite-agent secret get wandb_api_key) - -WANDB_API_KEY=$(buildkite-agent secret get wandb_api_key) +HF_API_KEY=$(buildkite-agent secret get hf_api_key) if [ -n "$MODAL_TOKEN_ID" ] && [ -n "$MODAL_TOKEN_SECRET" ]; then log "Retrieved Modal credentials from Buildkite secrets" @@ -63,15 +63,15 @@ MODAL_ENV="BUILDKITE_REPO=$BUILDKITE_REPO BUILDKITE_COMMIT=$BUILDKITE_COMMIT BUI case "$TEST_TYPE" in "encoder") log "Running encoder tests..." - MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_encoder_tests" + MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_encoder_tests" ;; "vae") log "Running VAE tests..." - MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_vae_tests" + MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_vae_tests" ;; "transformer") log "Running transformer tests..." - MODAL_COMMAND="$MODAL_ENV python3 -m modal run $MODAL_TEST_FILE::run_transformer_tests" + MODAL_COMMAND="$MODAL_ENV HF_API_KEY=$HF_API_KEY python3 -m modal run $MODAL_TEST_FILE::run_transformer_tests" ;; "ssim") log "Running SSIM tests..." diff --git a/fastvideo/layers/layernorm.py b/fastvideo/layers/layernorm.py index 7077418b9..091ab841a 100644 --- a/fastvideo/layers/layernorm.py +++ b/fastvideo/layers/layernorm.py @@ -40,22 +40,6 @@ def __init__( if self.has_weight: self.weight = nn.Parameter(self.weight) - def forward_diffusers(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Forward method that matches Diffusers RMSNorm implementation exactly.""" - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + - self.variance_epsilon) - - if self.has_weight and self.weight is not None: - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - hidden_states = hidden_states * self.weight - else: - hidden_states = hidden_states.to(input_dtype) - - return hidden_states - # if we do fully_shard(model.layer_norm), and we call layer_form.forward_native(input) instead of layer_norm(input), # we need to call model.layer_norm.register_fsdp_forward_method(model, "forward_native") to make sure fsdp2 hooks are triggered # for mixed precision and cpu offloading diff --git a/fastvideo/pipelines/stages/latent_preparation.py b/fastvideo/pipelines/stages/latent_preparation.py index 0cf15d7d8..3cd6a9574 100644 --- a/fastvideo/pipelines/stages/latent_preparation.py +++ b/fastvideo/pipelines/stages/latent_preparation.py @@ -112,6 +112,56 @@ def forward( return batch + def adjust_video_length(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> int: + """ + Adjust video length based on VAE version. + + Args: + batch: The current batch information. + fastvideo_args: The inference arguments. + + Returns: + The batch with adjusted video length. 
+ """ + + video_length = batch.num_frames + use_temporal_scaling_frames = fastvideo_args.pipeline_config.vae_config.use_temporal_scaling_frames + if use_temporal_scaling_frames: + temporal_scale_factor = fastvideo_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + latent_num_frames = (video_length - 1) // temporal_scale_factor + 1 + else: # stepvideo only + latent_num_frames = video_length // 17 * 3 + return int(latent_num_frames) + + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify latent preparation stage inputs.""" + result = VerificationResult() + result.add_check( + "prompt_or_embeds", None, lambda _: V.string_or_list_strings( + batch.prompt) or V.list_not_empty(batch.prompt_embeds)) + result.add_check("prompt_embeds", batch.prompt_embeds, + V.list_of_tensors) + result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, + V.positive_int) + result.add_check("generator", batch.generator, + V.generator_or_list_generators) + result.add_check("num_frames", batch.num_frames, V.positive_int) + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("latents", batch.latents, V.none_or_tensor) + return result + + def verify_output(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify latent preparation stage outputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, + [V.is_tensor, V.with_dims(5)]) + result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple) + return result + class CosmosLatentPreparationStage(PipelineStage): """ @@ -325,25 +375,6 @@ def retrieve_latents( return batch - def verify_input(self, batch: ForwardBatch, - fastvideo_args: FastVideoArgs) -> VerificationResult: - """Verify Cosmos latent preparation stage inputs.""" - result = VerificationResult() - result.add_check( - "prompt_or_embeds", None, lambda _: V.string_or_list_strings( - batch.prompt) or V.list_not_empty(batch.prompt_embeds)) - result.add_check("prompt_embeds", batch.prompt_embeds, - V.list_of_tensors) - result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, - V.positive_int) - result.add_check("generator", batch.generator, - V.generator_or_list_generators) - result.add_check("num_frames", batch.num_frames, V.positive_int) - result.add_check("height", batch.height, V.positive_int) - result.add_check("width", batch.width, V.positive_int) - result.add_check("latents", batch.latents, V.none_or_tensor) - return result - def adjust_video_length(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> int: """ @@ -366,6 +397,25 @@ def adjust_video_length(self, batch: ForwardBatch, latent_num_frames = video_length // 17 * 3 return int(latent_num_frames) + def verify_input(self, batch: ForwardBatch, + fastvideo_args: FastVideoArgs) -> VerificationResult: + """Verify Cosmos latent preparation stage inputs.""" + result = VerificationResult() + result.add_check( + "prompt_or_embeds", None, lambda _: V.string_or_list_strings( + batch.prompt) or V.list_not_empty(batch.prompt_embeds)) + result.add_check("prompt_embeds", batch.prompt_embeds, + V.list_of_tensors) + result.add_check("num_videos_per_prompt", batch.num_videos_per_prompt, + V.positive_int) + result.add_check("generator", batch.generator, + V.generator_or_list_generators) + result.add_check("num_frames", batch.num_frames, V.positive_int) + result.add_check("height", batch.height, 
V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("latents", batch.latents, V.none_or_tensor) + return result + def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: """Verify latent preparation stage outputs.""" diff --git a/fastvideo/tests/encoders/test_t5_encoder.py b/fastvideo/tests/encoders/test_t5_encoder.py index 0e93b53bb..27e2bf76c 100644 --- a/fastvideo/tests/encoders/test_t5_encoder.py +++ b/fastvideo/tests/encoders/test_t5_encoder.py @@ -21,35 +21,50 @@ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29503" -#BASE_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" -BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World" -MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, - local_dir=os.path.join( - 'data', BASE_MODEL_PATH)) -TEXT_ENCODER_PATH = os.path.join(MODEL_PATH, "text_encoder") -TOKENIZER_PATH = os.path.join(MODEL_PATH, "tokenizer") + +@pytest.fixture +def t5_model_paths(): + base_model_path = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + model_path = maybe_download_model(base_model_path, + local_dir=os.path.join( + 'data', base_model_path)) + text_encoder_path = os.path.join(model_path, "text_encoder") + tokenizer_path = os.path.join(model_path, "tokenizer") + return text_encoder_path, tokenizer_path + + +@pytest.fixture +def t5_large_model_paths(): + base_model_path = "nvidia/Cosmos-Predict2-2B-Video2World" + model_path = maybe_download_model(base_model_path, + local_dir=os.path.join( + 'data', base_model_path)) + text_encoder_path = os.path.join(model_path, "text_encoder") + tokenizer_path = os.path.join(model_path, "tokenizer") + return text_encoder_path, tokenizer_path @pytest.mark.usefixtures("distributed_setup") -def test_t5_encoder(): +def test_t5_encoder(t5_model_paths): # Initialize the two model implementations - hf_config = AutoConfig.from_pretrained(TEXT_ENCODER_PATH) + text_encoder_path, tokenizer_path = t5_model_paths + hf_config = AutoConfig.from_pretrained(text_encoder_path) print(hf_config) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") precision_str = "fp32" precision = PRECISION_TO_TYPE[precision_str] - model1 = UMT5EncoderModel.from_pretrained(TEXT_ENCODER_PATH).to( + model1 = UMT5EncoderModel.from_pretrained(text_encoder_path).to( precision).to(device).eval() - tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, + args = FastVideoArgs(model_path=text_encoder_path, pipeline_config=PipelineConfig(text_encoder_configs=(T5Config(),), text_encoder_precisions=(precision_str,)), pin_cpu_memory=False) loader = TextEncoderLoader() - model2 = loader.load(TEXT_ENCODER_PATH, args) + model2 = loader.load(text_encoder_path, args) model2 = model2.to(precision) model2.eval() @@ -62,7 +77,6 @@ def test_t5_encoder(): logger.info("Model1 has %s parameters", len(params1)) logger.info("Model2 has %s parameters", len(params2)) - weight_diffs = [] # check if embed_tokens are the same weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \ "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\ @@ -140,24 +154,25 @@ def test_t5_encoder(): @pytest.mark.usefixtures("distributed_setup") -def test_t5_large_encoder(): +def 
test_t5_large_encoder(t5_large_model_paths): # Initialize the two model implementations - hf_config = AutoConfig.from_pretrained(TEXT_ENCODER_PATH) + text_encoder_path, tokenizer_path = t5_large_model_paths + hf_config = AutoConfig.from_pretrained(text_encoder_path) print(hf_config) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") precision_str = "fp32" precision = PRECISION_TO_TYPE[precision_str] - model1 = T5EncoderModel.from_pretrained(TEXT_ENCODER_PATH).to( + model1 = T5EncoderModel.from_pretrained(text_encoder_path).to( precision).to(device).eval() - tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, + args = FastVideoArgs(model_path=text_encoder_path, pipeline_config=PipelineConfig(text_encoder_configs=(T5LargeConfig(),), text_encoder_precisions=(precision_str,)), pin_cpu_memory=False) loader = TextEncoderLoader() - model2 = loader.load(TEXT_ENCODER_PATH, args) + model2 = loader.load(text_encoder_path, args) model2 = model2.to(precision) model2.eval() @@ -170,30 +185,28 @@ def test_t5_large_encoder(): logger.info("Model1 has %s parameters", len(params1)) logger.info("Model2 has %s parameters", len(params2)) - # # Print parameter names for comparison - # logger.info("Model1 parameters:") - # for name in sorted(params1.keys()): - # logger.info(" %s: %s", name, params1[name].shape) + # Print parameter names for comparison + logger.info("Model1 parameters:") + for name in sorted(params1.keys()): + logger.info(" %s: %s", name, params1[name].shape) - # logger.info("Model2 parameters:") - # for name in sorted(params2.keys()): - # logger.info(" %s: %s", name, params2[name].shape) - - weight_diffs = [] - # check if embed_tokens are the same - # weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \ - # "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\ - # "encoder.block.{}.layer.1.DenseReluDense.wo.weight", \ - # "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight"] + logger.info("Model2 parameters:") + for name in sorted(params2.keys()): + logger.info(" %s: %s", name, params2[name].shape) + + #check if embed_tokens are the same + weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", \ + "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi.weight", \ + "encoder.block.{}.layer.1.DenseReluDense.wo.weight", "encoder.final_layer_norm.weight"] - # for idx in range(hf_config.num_hidden_layers): - # for w in weights: - # name1 = w.format(idx) - # name2 = w.format(idx) - # p1 = params1[name1] - # p2 = params2[name2] - # p2 = (p2.to_local() if isinstance(p2, DTensor) else p2).to(p1) - # assert_close(p1, p2, atol=1e-4, rtol=1e-4) + for idx in range(hf_config.num_hidden_layers): + for w in weights: + name1 = w.format(idx) + name2 = w.format(idx) + p1 = params1[name1] + p2 = params2[name2] + p2 = (p2.to_local() if isinstance(p2, DTensor) else p2).to(p1) + assert_close(p1, p2, atol=1e-4, rtol=1e-4) # Test with some sample prompts @@ -253,4 +266,4 @@ def test_t5_large_encoder(): assert mean_diff_hidden < 1e-4, \ f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}" assert max_diff_hidden < 1e-4, \ - f"Hidden states 
differ significantly: max diff = {max_diff_hidden.item()}" \ No newline at end of file + f"Hidden states differ significantly: max diff = {max_diff_hidden.item()}" diff --git a/fastvideo/tests/modal/pr_test.py b/fastvideo/tests/modal/pr_test.py index 937ab95a3..438c0e441 100644 --- a/fastvideo/tests/modal/pr_test.py +++ b/fastvideo/tests/modal/pr_test.py @@ -62,17 +62,17 @@ def run_test(pytest_command: str): sys.exit(result.returncode) -@app.function(gpu="H100:1", image=image, timeout=900) +@app.function(gpu="H100:1", image=image, timeout=900, secrets=[modal.Secret.from_dict({"HF_API_KEY": os.environ.get("HF_API_KEY", "")})]) def run_encoder_tests(): - run_test("pytest ./fastvideo/tests/encoders -vs") + run_test("hf auth login --token $HF_API_KEY && pytest ./fastvideo/tests/encoders -vs") -@app.function(gpu="L40S:1", image=image, timeout=900) +@app.function(gpu="L40S:1", image=image, timeout=900, secrets=[modal.Secret.from_dict({"HF_API_KEY": os.environ.get("HF_API_KEY", "")})]) def run_vae_tests(): - run_test("pytest ./fastvideo/tests/vaes -vs") + run_test("hf auth login --token $HF_API_KEY && pytest ./fastvideo/tests/vaes -vs") -@app.function(gpu="L40S:1", image=image, timeout=900) +@app.function(gpu="L40S:1", image=image, timeout=900, secrets=[modal.Secret.from_dict({"HF_API_KEY": os.environ.get("HF_API_KEY", "")})]) def run_transformer_tests(): - run_test("pytest ./fastvideo/tests/transformers -vs") + run_test("hf auth login --token $HF_API_KEY && pytest ./fastvideo/tests/transformers -vs") @app.function(gpu="L40S:2", image=image, timeout=2700) def run_ssim_tests(): diff --git a/fastvideo/tests/transformers/test_cosmos.py b/fastvideo/tests/transformers/test_cosmos.py index 2b6ca5dbb..bcb7d51e9 100644 --- a/fastvideo/tests/transformers/test_cosmos.py +++ b/fastvideo/tests/transformers/test_cosmos.py @@ -21,7 +21,6 @@ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29504" -# BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Text2Image" BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World" MODEL_PATH = maybe_download_model(BASE_MODEL_PATH, local_dir=os.path.join( @@ -132,126 +131,4 @@ def test_cosmos2_transformer(): logger.info("Mean Diff: %s", mean_diff.item()) assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" # mean diff - assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" - - -@pytest.mark.usefixtures("distributed_setup") -def test_cosmos2_transformer_video2world(): - """Test Cosmos2 Video2World variant""" - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - precision = torch.bfloat16 - precision_str = "bf16" - - # Use Video2World model path - base_model_path = "nvidia/Cosmos-Predict2-2B-Video2World" - model_path = maybe_download_model(base_model_path, - local_dir=os.path.join( - 'data', base_model_path)) - transformer_path = os.path.join(model_path, "transformer") - - # Use torch attention backend for exact diffusers matching - cosmos_config = CosmosVideoConfig() - cosmos_config.arch_config.attention_backend = "torch" - - args = FastVideoArgs(model_path=transformer_path, - dit_cpu_offload=False, - use_fsdp_inference=False, - enable_torch_compile=False, - disable_autocast=True, - pipeline_config=PipelineConfig(dit_config=cosmos_config, dit_precision=precision_str)) - - loader = TransformerLoader() - model2 = loader.load(transformer_path, args).to(device, dtype=precision) - - model1 = CosmosTransformer3DModel.from_pretrained( - transformer_path, torch_dtype=precision).to(device, 
dtype=precision).requires_grad_(False) - - total_params = sum(p.numel() for p in model1.parameters()) - # Calculate weight sum for model1 (converting to float64 to avoid overflow) - weight_sum_model1 = sum( - p.to(torch.float64).sum().item() for p in model1.parameters()) - # Also calculate mean for more stable comparison - weight_mean_model1 = weight_sum_model1 / total_params - logger.info("Model 1 weight sum: %s", weight_sum_model1) - logger.info("Model 1 weight mean: %s", weight_mean_model1) - - # Calculate weight sum for model2 (converting to float64 to avoid overflow) - total_params_model2 = sum(p.numel() for p in model2.parameters()) - weight_sum_model2 = sum( - p.to(torch.float64).sum().item() for p in model2.parameters()) - # Also calculate mean for more stable comparison - weight_mean_model2 = weight_sum_model2 / total_params_model2 - logger.info("Model 2 weight sum: %s", weight_sum_model2) - logger.info("Model 2 weight mean: %s", weight_mean_model2) - - weight_sum_diff = abs(weight_sum_model1 - weight_sum_model2) - logger.info("Weight sum difference: %s", weight_sum_diff) - weight_mean_diff = abs(weight_mean_model1 - weight_mean_model2) - logger.info("Weight mean difference: %s", weight_mean_diff) - - # Set both models to eval mode - model1 = model1.eval() - model2 = model2.eval() - - # Create identical inputs for both models - batch_size = 1 - seq_len = 30 - - # Video latents [B, C, T, H, W] - Video2World has additional condition channel - hidden_states = torch.randn(batch_size, - 17, # 16 + 1 for condition channel - 8, # Multiple frames for video - 32, # Height patches - 32, # Width patches - device=device, - dtype=precision) - - # Text embeddings [B, L, D] - encoder_hidden_states = torch.randn(batch_size, - seq_len, - 1024, # T5 embedding dimension - device=device, - dtype=precision) - - # Timestep - timestep = torch.tensor([500], device=device, dtype=precision) - - # padding mask - padding_mask = hidden_states.new_zeros(1, 1, 32, 32, device=device, dtype=precision) - - forward_batch = ForwardBatch( - data_type="dummy", - ) - - with torch.autocast('cuda', dtype=precision, enabled=False): - output1 = model1( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - padding_mask=padding_mask, - return_dict=False, - )[0] - with set_forward_context( - current_timestep=0, - attn_metadata=None, - forward_batch=forward_batch, - ): - output2 = model2(hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - padding_mask=padding_mask) - - # Check if outputs have the same shape - assert output1.shape == output2.shape, f"Output shapes don't match: {output1.shape} vs {output2.shape}" - assert output1.dtype == output2.dtype, f"Output dtype don't match: {output1.dtype} vs {output2.dtype}" - - # Check if outputs are similar (allowing for small numerical differences) - max_diff = torch.max(torch.abs(output1 - output2)) - mean_diff = torch.mean(torch.abs(output1 - output2)) - logger.info("Max Diff: %s", max_diff.item()) - logger.info("Mean Diff: %s", mean_diff.item()) - - # With torch attention backend, outputs should now match closely - assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}" - # mean diff assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}" \ No newline at end of file From 01465a63a48d08e191c55897e24f51924058b2fb Mon Sep 17 00:00:00 2001 From: kevin314 Date: Tue, 28 Oct 2025 02:26:37 +0000 Subject: [PATCH 4/5] Move new text encoding logic to 
 cosmos pipeline

---
 fastvideo/configs/pipelines/cosmos.py       |  7 ++
 fastvideo/models/encoders/t5.py             | 11 ++-
 fastvideo/pipelines/stages/denoising.py     | 10 +--
 .../pipelines/stages/input_validation.py    |  2 +-
 fastvideo/pipelines/stages/text_encoding.py | 15 ----
 fastvideo/worker/multiproc_executor.py      |  7 --
 test_fastvideo_pipeline.py                  | 74 -------------------
 7 files changed, 14 insertions(+), 112 deletions(-)
 delete mode 100644 test_fastvideo_pipeline.py

diff --git a/fastvideo/configs/pipelines/cosmos.py b/fastvideo/configs/pipelines/cosmos.py
index 3ca78fe0f..c5f382cc7 100644
--- a/fastvideo/configs/pipelines/cosmos.py
+++ b/fastvideo/configs/pipelines/cosmos.py
@@ -25,6 +25,13 @@ def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
     if nan_count > 0:
         hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)
 
+    # Zero out embeddings beyond actual sequence length
+    if outputs.attention_mask is not None:
+        attention_mask = outputs.attention_mask
+        lengths = attention_mask.sum(dim=1).cpu()
+        for i, length in enumerate(lengths):
+            hidden_state[i, length:] = 0
+
     return hidden_state
 
 
diff --git a/fastvideo/models/encoders/t5.py b/fastvideo/models/encoders/t5.py
index 4a8a1711c..36f012e2e 100644
--- a/fastvideo/models/encoders/t5.py
+++ b/fastvideo/models/encoders/t5.py
@@ -180,7 +180,6 @@ def __init__(self,
 
         self.qkv_proj = QKVParallelLinear(
             self.d_model,
-            #self.d_model // self.total_num_heads,
             self.key_value_proj_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
@@ -199,7 +198,6 @@ def __init__(self,
                 padding_size=self.relative_attention_num_buckets,
                 quant_config=quant_config)
         self.o = RowParallelLinear(
-            #self.d_model,
             self.total_num_heads * self.key_value_proj_dim,
             self.d_model,
             bias=False,
@@ -299,12 +297,10 @@ def forward(
     ) -> torch.Tensor:
         bs, seq_len, _ = hidden_states.shape
         num_seqs = bs
-        #n, c = self.n_heads, self.d_model // self.total_num_heads
         n, c = self.n_heads, self.key_value_proj_dim
         qkv, _ = self.qkv_proj(hidden_states)
         # Projection of 'own' hidden state (self-attention). No GQA here.
- #q, k, v = qkv.split(self.inner_dim, dim=-1) - q, k, v = qkv.split(self.qkv_proj.output_sizes, dim=-1) + q, k, v = qkv.split(self.inner_dim, dim=-1) q = q.reshape(bs, seq_len, n, c) k = k.reshape(bs, seq_len, n, c) v = v.reshape(bs, seq_len, n, c) @@ -544,7 +540,10 @@ def forward( attn_metadata=attn_metadata, ) - return BaseEncoderOutput(last_hidden_state=hidden_states) + return BaseEncoderOutput( + last_hidden_state=hidden_states, + attention_mask=attention_mask, + ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 1f231ebaa..00a2b3ea2 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -803,7 +803,6 @@ def forward( conditioning_latents = getattr(batch, 'conditioning_latents', None) unconditioning_latents = conditioning_latents - # Sampling loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if hasattr(self, 'interrupt') and self.interrupt: @@ -824,7 +823,6 @@ def forward( dtype=target_dtype, enabled=autocast_enabled): - # Conditional forward pass cond_latent = latents * c_in if hasattr( @@ -840,7 +838,6 @@ def forward( cond_latent = cond_latent.to(target_dtype) - # Apply conditional timestep processing cond_timestep = timestep if hasattr(batch, 'cond_indicator' ) and batch.cond_indicator is not None: @@ -913,18 +910,16 @@ def forward( attn_metadata=None, forward_batch=batch, ): - # Use uncond_mask for unconditional pass if available uncond_condition_mask = batch.uncond_mask.to( target_dtype ) if hasattr( batch, 'uncond_mask' ) and batch.uncond_mask is not None else condition_mask - # Apply same conditional timestep processing for unconditional pass uncond_timestep = timestep if hasattr(batch, 'uncond_indicator' ) and batch.uncond_indicator is not None: - sigma_conditioning = 0.0001 # Same as Diffusers default + sigma_conditioning = 0.0001 t_conditioning = sigma_conditioning / ( sigma_conditioning + 1) uncond_timestep = batch.uncond_indicator * t_conditioning + ( @@ -947,7 +942,6 @@ def forward( c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype) - # Apply conditional indicator masking for unconditional prediction like diffusers if hasattr( batch, 'uncond_indicator' ) and batch.uncond_indicator is not None and unconditioning_latents is not None: @@ -968,7 +962,6 @@ def forward( i, current_sigma) noise_for_scheduler = final_pred - # Debug: Check for NaN values before scheduler step if torch.isnan(noise_for_scheduler).sum() > 0: logger.error( "Step %s: NaN detected in noise_for_scheduler, sum: %s", @@ -988,7 +981,6 @@ def forward( progress_bar.update() - # Update batch with final latents batch.latents = latents return batch diff --git a/fastvideo/pipelines/stages/input_validation.py b/fastvideo/pipelines/stages/input_validation.py index 62f6c06c7..5b06e968e 100644 --- a/fastvideo/pipelines/stages/input_validation.py +++ b/fastvideo/pipelines/stages/input_validation.py @@ -41,7 +41,7 @@ def _generate_seeds(self, batch: ForwardBatch, batch.seeds = seeds # Peiyuan: using GPU seed will cause A100 and H100 to generate different results... 
batch.generator = [ - torch.Generator(device="cpu").manual_seed(seed) for seed in seeds + torch.Generator("cpu").manual_seed(seed) for seed in seeds ] def forward( diff --git a/fastvideo/pipelines/stages/text_encoding.py b/fastvideo/pipelines/stages/text_encoding.py index 0edde936a..1e9454de3 100644 --- a/fastvideo/pipelines/stages/text_encoding.py +++ b/fastvideo/pipelines/stages/text_encoding.py @@ -70,14 +70,6 @@ def forward( return_attention_mask=True, ) - # Zero out embeddings beyond actual sequence length - for prompt_embeds, attention_mask in zip(prompt_embeds_list, - prompt_masks_list, - strict=False): - lengths = attention_mask.sum(dim=1).cpu() - for i, length in enumerate(lengths): - prompt_embeds[i, length:] = 0 - for pe in prompt_embeds_list: batch.prompt_embeds.append(pe) if batch.prompt_attention_mask is not None: @@ -94,13 +86,6 @@ def forward( return_attention_mask=True, ) - for neg_embeds, neg_mask in zip(neg_embeds_list, - neg_masks_list, - strict=False): - lengths = neg_mask.sum(dim=1).cpu() - for i, length in enumerate(lengths): - neg_embeds[i, length:] = 0 - assert batch.negative_prompt_embeds is not None for ne in neg_embeds_list: batch.negative_prompt_embeds.append(ne) diff --git a/fastvideo/worker/multiproc_executor.py b/fastvideo/worker/multiproc_executor.py index 310f3e7cf..cc8c419af 100644 --- a/fastvideo/worker/multiproc_executor.py +++ b/fastvideo/worker/multiproc_executor.py @@ -32,13 +32,6 @@ def _init_executor(self) -> None: self.world_size = self.fastvideo_args.num_gpus self.shutting_down = False - # Initialize CUDA before setting up multiprocessing to ensure - # maybe_force_spawn() correctly detects CUDA and forces 'spawn' mode. - # This prevents "Cannot re-initialize CUDA in forked subprocess" errors. - import torch - if torch.cuda.is_available(): - torch.cuda.init() - set_multiproc_executor_envs() # Check if master_port is provided in fastvideo_args diff --git a/test_fastvideo_pipeline.py b/test_fastvideo_pipeline.py deleted file mode 100644 index 2913767fe..000000000 --- a/test_fastvideo_pipeline.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple script to generate a video using the FastVideo generator. -""" - -import os -import sys - -from fastvideo.entrypoints.video_generator import VideoGenerator - - -def generate_video() -> bool: - """Generate a video using the FastVideo generator.""" - - # Configuration - #input_image_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/tennis.jpg" - #prompt = "A tennis ball bouncing on a racquet, the ball moves in a smooth arc as it hits the strings and rebounds with natural physics. The racquet strings vibrate slightly from the impact, and the ball continues its trajectory with realistic motion." - input_image_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/yellow-scrubber.png" - prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess." 
- negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality." - output_path = "/mnt/fast-disks/hao_lab/kevin/FastVideo/cosmos2_fastvideo_output.mp4" - - # Check if input image exists - if not os.path.exists(input_image_path): - print(f"Error: Input image not found: {input_image_path}") - return False - - try: - # Create video generator - print("Creating FastVideo generator...") - generator = VideoGenerator.from_pretrained( - model_path="nvidia/Cosmos-Predict2-2B-Video2World", - num_gpus=1, - ) - - print("Generator created successfully") - - # Run inference - print("Generating video...") - result = generator.generate_video(height=704, - width=1280, - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=21, - image_path=input_image_path, - num_inference_steps=35, - guidance_scale=7.0, - seed=1, - save_video=True, - output_path=output_path, - fps=16) - - if result: - print("Video generation completed successfully!") - return True - else: - print("Video generation failed - no result returned") - return False - - except Exception as e: - print(f"Error during video generation: {e}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - success = generate_video() - if success: - print("✅ Video generation completed successfully") - sys.exit(0) - else: - print("❌ Video generation failed") - sys.exit(1) From 01f8359e5a62096b12d30db35eb015590a20bd0a Mon Sep 17 00:00:00 2001 From: kevin314 Date: Tue, 28 Oct 2025 05:04:02 +0000 Subject: [PATCH 5/5] Trigger Build