
Commit a4d2086

Clean up logs
1 parent 26c2f0f commit a4d2086

13 files changed (+51, -808 lines)


fastvideo/configs/models/vaes/cosmosvae.py

Lines changed: 3 additions & 4 deletions
@@ -8,8 +8,6 @@

 @dataclass
 class CosmosVAEArchConfig(VAEArchConfig):
-    _class_name: str = "AutoencoderKLWan"
-    _diffusers_version: str = "0.34.0.dev0"
     _name_or_path: str = ""
     base_dim: int = 96
     z_dim: int = 16
@@ -76,7 +74,8 @@ def __post_init__(self):

 @dataclass
 class CosmosVAEConfig(VAEConfig):
-    arch_config: CosmosVAEArchConfig = field(default_factory=CosmosVAEArchConfig)
+    arch_config: CosmosVAEArchConfig = field(
+        default_factory=CosmosVAEArchConfig)
     use_feature_cache: bool = True

     use_tiling: bool = False
@@ -85,4 +84,4 @@ class CosmosVAEConfig(VAEConfig):

     def __post_init__(self):
         self.blend_num_frames = (self.tile_sample_min_num_frames -
-                                 self.tile_sample_stride_num_frames) * 2
+                                 self.tile_sample_stride_num_frames) * 2
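
The reformatted arch_config default keeps the standard dataclass pattern: a nested config object has to be supplied through field(default_factory=...) because dataclasses reject instance defaults. A minimal, self-contained sketch of the pattern (the classes below are stand-ins, not the actual FastVideo configs):

from dataclasses import dataclass, field

@dataclass
class ArchConfig:          # stand-in for CosmosVAEArchConfig
    base_dim: int = 96
    z_dim: int = 16

@dataclass
class VAEConfigSketch:     # stand-in for CosmosVAEConfig
    # default_factory builds a fresh ArchConfig per instance,
    # avoiding a shared mutable default
    arch_config: ArchConfig = field(default_factory=ArchConfig)
    use_feature_cache: bool = True

cfg = VAEConfigSketch()
assert cfg.arch_config.z_dim == 16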

fastvideo/configs/pipelines/cosmos.py

Lines changed: 10 additions & 64 deletions
@@ -23,90 +23,37 @@ def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
     if hidden_state is None:
         raise ValueError("T5 Large outputs missing last_hidden_state")

-    # Check for NaN values and provide debugging info
     nan_count = torch.isnan(hidden_state).sum()
     if nan_count > 0:
-        print(f"WARNING: Found {nan_count} NaN values in T5 Large hidden states")
-        print(f"Hidden state shape: {hidden_state.shape}")
-        print(f"Hidden state dtype: {hidden_state.dtype}")
-        print(f"Hidden state device: {hidden_state.device}")
-        # Replace NaN values with zeros to avoid pipeline failure
         hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)

-    # Return raw last_hidden_state (no truncation/padding)
     return hidden_state


-@dataclass
-class CosmosVideoConfigFixed(CosmosVideoConfig):
-    """Fixed Cosmos Video Config that matches original Cosmos2 Video2World configuration."""
-
-    def update_model_arch(self, config: dict) -> None:
-        """Update model architecture config with HF config, but fix parameters to match original Cosmos2."""
-        # First, apply the standard update
-        super().update_model_arch(config)
-
-        # CRITICAL FIXES to match original Cosmos2 Video2World configuration:
-
-        # 1. Fix input channels: should be 16 (VAE) + 1 (condition mask) = 17
-        setattr(self.arch_config, 'in_channels', 17)
-
-        # 2. Fix output channels: should be 16 (VAE latent dimension)
-        setattr(self.arch_config, 'out_channels', 16)
-
-        # 3. Fix model architecture to match Cosmos2 2B model
-        setattr(self.arch_config, 'num_attention_heads', 16)
-        setattr(self.arch_config, 'attention_head_dim', 128)  # Fixed: should be 128, not 64
-        setattr(self.arch_config, 'num_layers', 28)
-        setattr(self.arch_config, 'hidden_size', 2048)  # 16 * 128 = 2048
-
-        # 4. Fix patch size to match original
-        setattr(self.arch_config, 'patch_size', (1, 2, 2))
-
-        # 5. Fix max size to match original
-        setattr(self.arch_config, 'max_size', (128, 240, 240))
-
-        # 6. Fix text embedding dimension
-        setattr(self.arch_config, 'text_embed_dim', 1024)
-
-        # 7. Fix adaln lora dimension
-        setattr(self.arch_config, 'adaln_lora_dim', 256)
-
-        # 8. Fix rope scale to match original
-        setattr(self.arch_config, 'rope_scale', (1.0, 3.0, 3.0))
-
-        # 9. Enable concat padding mask
-        setattr(self.arch_config, 'concat_padding_mask', True)
-
-        # 10. Set num_channels_latents to 16 (VAE output dim)
-        setattr(self.arch_config, 'num_channels_latents', 16)
-
-
 @dataclass
 class CosmosConfig(PipelineConfig):
-    """Configuration for Cosmos2 Video2World pipeline matching original implementation."""
+    """Configuration for Cosmos2 Video2World pipeline matching diffusers."""

-    # DiT configuration matching Cosmos2 2B model
-    dit_config: DiTConfig = field(default_factory=CosmosVideoConfigFixed)
+
+    dit_config: DiTConfig = field(default_factory=CosmosVideoConfig)

-    # VAE configuration matching Cosmos2
+
     vae_config: VAEConfig = field(default_factory=CosmosVAEConfig)

-    # Text encoding configuration
+
     text_encoder_configs: tuple[EncoderConfig, ...] = field(
         default_factory=lambda: (T5LargeConfig(), ))
     postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
                                   ...] = field(default_factory=lambda:
                                                (t5_large_postprocess_text, ))

-    # Precision for each component
+
     dit_precision: str = "bf16"
     vae_precision: str = "fp16"
     text_encoder_precisions: tuple[str, ...] = field(
         default_factory=lambda: ("bf16",))

-    # Cosmos2 Video2World specific parameters
-    conditioning_strategy: str = "frame_replace"  # Match original ConditioningStrategy.FRAME_REPLACE
+    conditioning_strategy: str = "frame_replace"
     min_num_conditional_frames: int = 1
     max_num_conditional_frames: int = 2
     sigma_conditional: float = 0.0001
@@ -115,13 +62,12 @@ class CosmosConfig(PipelineConfig):
     state_t: int = 24
     text_encoder_class: str = "T5"

-    # Denoising parameters
+
     embedded_cfg_scale: int = 6
-    flow_shift: float = 1.0  # Changed to 1.0 to match diffusers (no shift transformation)
+    flow_shift: float = 1.0

     def __post_init__(self):
         self.vae_config.load_encoder = True
         self.vae_config.load_decoder = True

-        # Store the VAE's latent dimension to use later
-        self._vae_latent_dim = 16  # From CosmosVAEArchConfig.z_dim
+        self._vae_latent_dim = 16
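
With the debug prints gone, the NaN handling in t5_large_postprocess_text reduces to a single masked_fill over the encoder's last_hidden_state. A small sketch of just that retained behaviour on a dummy tensor (BaseEncoderOutput and the actual T5 call are omitted here):

import torch

def scrub_nans(hidden_state: torch.Tensor) -> torch.Tensor:
    # mirrors the retained logic: replace NaNs with zeros, return the raw hidden state
    if torch.isnan(hidden_state).sum() > 0:
        hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)
    return hidden_state

x = torch.tensor([[1.0, float("nan"), 2.0]])
print(scrub_nans(x))  # tensor([[1., 0., 2.]])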

fastvideo/configs/sample/cosmos.py

Lines changed: 12 additions & 27 deletions
@@ -1,33 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
-from dataclasses import dataclass, field
+from dataclasses import dataclass

-from fastvideo.configs.sample.base import CacheParams
+from fastvideo.configs.sample.base import SamplingParam


 @dataclass
-class CosmosTeaCacheParams(CacheParams):
-    cache_type: str = "teacache"
-    teacache_thresh: float = 0.0
-    use_ret_steps: bool = True
-    ret_steps_coeffs: list[float] = field(default_factory=list)
-    non_ret_steps_coeffs: list[float] = field(default_factory=list)
+class Cosmos_Predict2_2B_Video2World_SamplingParam(SamplingParam):
+    # Video parameters
+    height: int = 704
+    width: int = 1280
+    num_frames: int = 93
+    fps: int = 16

-    @property
-    def coefficients(self) -> list[float]:
-        if self.use_ret_steps:
-            return self.ret_steps_coeffs
-        else:
-            return self.non_ret_steps_coeffs
-
-    @property
-    def ret_steps(self) -> int:
-        if self.use_ret_steps:
-            return 5 * 2
-        else:
-            return 1 * 2
-
-    def get_cutoff_steps(self, num_inference_steps: int) -> int:
-        if self.use_ret_steps:
-            return num_inference_steps * 2
-        else:
-            return num_inference_steps * 2 - 2
+    # Denoising stage
+    guidance_scale: float = 7.0
+    negative_prompt: str = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
+    num_inference_steps: int = 35
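
Since the new sampling parameters form a plain dataclass, per-run overrides can presumably be expressed with dataclasses.replace. A hypothetical usage sketch, assuming the class is importable from fastvideo/configs/sample/cosmos.py and that downstream code consumes these fields unchanged (the override values are illustrative):

from dataclasses import replace

from fastvideo.configs.sample.cosmos import (
    Cosmos_Predict2_2B_Video2World_SamplingParam as CosmosSamplingParam)

params = CosmosSamplingParam()  # 704x1280, 93 frames, 16 fps defaults
fast_preview = replace(params, num_inference_steps=20, num_frames=29)
print(fast_preview.height, fast_preview.num_inference_steps)  # 704 20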

fastvideo/layers/layernorm.py

Lines changed: 1 addition & 2 deletions
@@ -39,14 +39,13 @@ def __init__(
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_diffusers(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """Forward method that matches Diffusers RMSNorm implementation exactly."""
         input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

         if self.has_weight and self.weight is not None:
-            # convert into half-precision if necessary (match Diffusers exactly)
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
             hidden_states = hidden_states * self.weight
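
The renamed forward_diffusers keeps the Diffusers-style RMSNorm numerics: variance computed in float32, rsqrt scaling, and a cast back to half precision only when the learned weight is fp16/bf16. A standalone sketch of that computation (the eps value and the optional weight argument stand in for the class attributes and are assumptions):

import torch

def rms_norm_diffusers_style(hidden_states: torch.Tensor,
                             weight: torch.Tensor | None = None,
                             eps: float = 1e-6) -> torch.Tensor:
    # compute the variance in float32 for numerical stability
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    if weight is not None:
        # cast back to half precision before scaling, as in the diff above
        if weight.dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(weight.dtype)
        hidden_states = hidden_states * weight
    return hidden_states

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
print(rms_norm_diffusers_style(x, w).dtype)  # torch.bfloat16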

fastvideo/layers/rotary_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def apply_rotary_emb(
     """
     if use_real:
         cos, sin = freqs_cis  # [S, D]
-        # Match Diffusers exact broadcasting (sequence_dim=2 case)
+        # Match Diffusers broadcasting (sequence_dim=2 case)
         cos = cos[None, None, :, :]
         sin = sin[None, None, :, :]
         cos, sin = cos.to(x.device), sin.to(x.device)
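
For context, the broadcasting kept here turns [S, D] cos/sin tables into [1, 1, S, D] so they line up with query/key tensors laid out as [batch, heads, seq, head_dim]. A rough sketch of a real-valued rotary application under that layout (the interleaved even/odd channel pairing below is an assumption; the actual helper may pair channels differently):

import torch

def apply_rotary_emb_real(x: torch.Tensor, cos: torch.Tensor,
                          sin: torch.Tensor) -> torch.Tensor:
    """Rotate a [B, H, S, D] tensor with real cos/sin tables of shape [S, D]."""
    # broadcast [S, D] -> [1, 1, S, D], matching the sequence_dim=2 layout
    cos = cos[None, None, :, :].to(x.device)
    sin = sin[None, None, :, :].to(x.device)
    # pair up even/odd channels and rotate: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    x_rot = torch.stack((-x2, x1), dim=-1).reshape_as(x)
    return (x.float() * cos + x_rot.float() * sin).to(x.dtype)

q = torch.randn(1, 16, 32, 128)
cos, sin = torch.randn(32, 128), torch.randn(32, 128)
print(apply_rotary_emb_real(q, cos, sin).shape)  # torch.Size([1, 16, 32, 128])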
