@@ -23,90 +23,37 @@ def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
2323 if hidden_state is None :
2424 raise ValueError ("T5 Large outputs missing last_hidden_state" )
2525
26- # Check for NaN values and provide debugging info
2726 nan_count = torch .isnan (hidden_state ).sum ()
2827 if nan_count > 0 :
29- print (f"WARNING: Found { nan_count } NaN values in T5 Large hidden states" )
30- print (f"Hidden state shape: { hidden_state .shape } " )
31- print (f"Hidden state dtype: { hidden_state .dtype } " )
32- print (f"Hidden state device: { hidden_state .device } " )
33- # Replace NaN values with zeros to avoid pipeline failure
3428 hidden_state = hidden_state .masked_fill (torch .isnan (hidden_state ), 0.0 )
3529
36- # Return raw last_hidden_state (no truncation/padding)
3730 return hidden_state
3831
3932
40- @dataclass
41- class CosmosVideoConfigFixed (CosmosVideoConfig ):
42- """Fixed Cosmos Video Config that matches original Cosmos2 Video2World configuration."""
43-
44- def update_model_arch (self , config : dict ) -> None :
45- """Update model architecture config with HF config, but fix parameters to match original Cosmos2."""
46- # First, apply the standard update
47- super ().update_model_arch (config )
48-
49- # CRITICAL FIXES to match original Cosmos2 Video2World configuration:
50-
51- # 1. Fix input channels: should be 16 (VAE) + 1 (condition mask) = 17
52- setattr (self .arch_config , 'in_channels' , 17 )
53-
54- # 2. Fix output channels: should be 16 (VAE latent dimension)
55- setattr (self .arch_config , 'out_channels' , 16 )
56-
57- # 3. Fix model architecture to match Cosmos2 2B model
58- setattr (self .arch_config , 'num_attention_heads' , 16 )
59- setattr (self .arch_config , 'attention_head_dim' , 128 ) # Fixed: should be 128, not 64
60- setattr (self .arch_config , 'num_layers' , 28 )
61- setattr (self .arch_config , 'hidden_size' , 2048 ) # 16 * 128 = 2048
62-
63- # 4. Fix patch size to match original
64- setattr (self .arch_config , 'patch_size' , (1 , 2 , 2 ))
65-
66- # 5. Fix max size to match original
67- setattr (self .arch_config , 'max_size' , (128 , 240 , 240 ))
68-
69- # 6. Fix text embedding dimension
70- setattr (self .arch_config , 'text_embed_dim' , 1024 )
71-
72- # 7. Fix adaln lora dimension
73- setattr (self .arch_config , 'adaln_lora_dim' , 256 )
74-
75- # 8. Fix rope scale to match original
76- setattr (self .arch_config , 'rope_scale' , (1.0 , 3.0 , 3.0 ))
77-
78- # 9. Enable concat padding mask
79- setattr (self .arch_config , 'concat_padding_mask' , True )
80-
81- # 10. Set num_channels_latents to 16 (VAE output dim)
82- setattr (self .arch_config , 'num_channels_latents' , 16 )
83-
84-
8533@dataclass
8634class CosmosConfig (PipelineConfig ):
87- """Configuration for Cosmos2 Video2World pipeline matching original implementation ."""
35+ """Configuration for Cosmos2 Video2World pipeline matching diffusers ."""
8836
89- # DiT configuration matching Cosmos2 2B model
90- dit_config : DiTConfig = field (default_factory = CosmosVideoConfigFixed )
37+
38+ dit_config : DiTConfig = field (default_factory = CosmosVideoConfig )
9139
92- # VAE configuration matching Cosmos2
40+
9341 vae_config : VAEConfig = field (default_factory = CosmosVAEConfig )
9442
95- # Text encoding configuration
43+
9644 text_encoder_configs : tuple [EncoderConfig , ...] = field (
9745 default_factory = lambda : (T5LargeConfig (), ))
9846 postprocess_text_funcs : tuple [Callable [[BaseEncoderOutput ], torch .Tensor ],
9947 ...] = field (default_factory = lambda :
10048 (t5_large_postprocess_text , ))
10149
102- # Precision for each component
50+
10351 dit_precision : str = "bf16"
10452 vae_precision : str = "fp16"
10553 text_encoder_precisions : tuple [str , ...] = field (
10654 default_factory = lambda : ("bf16" ,))
10755
108- # Cosmos2 Video2World specific parameters
109- conditioning_strategy : str = "frame_replace" # Match original ConditioningStrategy.FRAME_REPLACE
56+ conditioning_strategy : str = "frame_replace"
11057 min_num_conditional_frames : int = 1
11158 max_num_conditional_frames : int = 2
11259 sigma_conditional : float = 0.0001
@@ -115,13 +62,12 @@ class CosmosConfig(PipelineConfig):
11562 state_t : int = 24
11663 text_encoder_class : str = "T5"
11764
118- # Denoising parameters
65+
11966 embedded_cfg_scale : int = 6
120- flow_shift : float = 1.0 # Changed to 1.0 to match diffusers (no shift transformation)
67+ flow_shift : float = 1.0
12168
12269 def __post_init__ (self ):
12370 self .vae_config .load_encoder = True
12471 self .vae_config .load_decoder = True
12572
126- # Store the VAE's latent dimension to use later
127- self ._vae_latent_dim = 16 # From CosmosVAEArchConfig.z_dim
73+ self ._vae_latent_dim = 16
0 commit comments