@@ -82,6 +82,10 @@ def forward(
8282 if latents is None :
8383 raise ValueError ("Latents must be provided" )
8484
85+ print (f"[FASTVIDEO VAE DEBUG] Before scaling/shifting - latents sum: { latents .float ().sum ().item ():.6f} , shape: { latents .shape } , dtype: { latents .dtype } " )
86+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
87+ f .write (f"[FASTVIDEO VAE DEBUG] Before scaling/shifting - latents sum: { latents .float ().sum ().item ():.6f} , shape: { latents .shape } , dtype: { latents .dtype } \n " )
88+
8589 # Skip decoding if output type is latent
8690 if fastvideo_args .output_type == "latent" :
8791 image = latents
@@ -92,18 +96,40 @@ def forward(
9296 vae_autocast_enabled = (vae_dtype != torch .float32
9397 ) and not fastvideo_args .disable_autocast
9498
95- # TEMPORARY: Handle diffusers VAE compatibility
96- if hasattr (self .vae , 'scaling_factor' ):
99+ # Apply latents normalization for Cosmos VAE (matching diffusers implementation)
100+ # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:1000-1010
101+ if hasattr (self .vae , 'config' ) and hasattr (self .vae .config , 'latents_mean' ) and hasattr (self .vae .config , 'latents_std' ):
102+ # Get scheduler for sigma_data
103+ pipeline = self .pipeline () if self .pipeline else None
104+ sigma_data = 1.0 # default
105+ if pipeline and hasattr (pipeline , 'modules' ) and 'scheduler' in pipeline .modules :
106+ scheduler = pipeline .modules ['scheduler' ]
107+ if hasattr (scheduler , 'config' ) and hasattr (scheduler .config , 'sigma_data' ):
108+ sigma_data = scheduler .config .sigma_data
109+
110+ latents_mean = (
111+ torch .tensor (self .vae .config .latents_mean )
112+ .view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
113+ .to (latents .device , latents .dtype )
114+ )
115+ latents_std = (
116+ torch .tensor (self .vae .config .latents_std )
117+ .view (1 , self .vae .config .z_dim , 1 , 1 , 1 )
118+ .to (latents .device , latents .dtype )
119+ )
120+ latents = latents * latents_std / sigma_data + latents_mean
121+ # Fallback to scaling_factor for other VAE types
122+ elif hasattr (self .vae , 'scaling_factor' ):
97123 if isinstance (self .vae .scaling_factor , torch .Tensor ):
98124 latents = latents / self .vae .scaling_factor .to (
99125 latents .device , latents .dtype )
100126 else :
101127 latents = latents / self .vae .scaling_factor
102128 elif hasattr (self .vae , 'config' ) and hasattr (self .vae .config , 'scaling_factor' ):
103- # Fallback to config scaling factor for diffusers VAE
129+ # Fallback to config scaling factor for other diffusers VAEs
104130 latents = latents / self .vae .config .scaling_factor
105131
106- # Apply shifting if needed
132+ # Apply shifting if needed (for other VAE types)
107133 if (hasattr (self .vae , "shift_factor" )
108134 and self .vae .shift_factor is not None ):
109135 if isinstance (self .vae .shift_factor , torch .Tensor ):
@@ -112,6 +138,10 @@ def forward(
112138 else :
113139 latents += self .vae .shift_factor
114140
141+ print (f"[FASTVIDEO VAE DEBUG] After scaling/shifting - latents sum: { latents .float ().sum ().item ():.6f} " )
142+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
143+ f .write (f"[FASTVIDEO VAE DEBUG] After scaling/shifting - latents sum: { latents .float ().sum ().item ():.6f} \n " )
144+
115145 # Decode latents
116146 with torch .autocast (device_type = "cuda" ,
117147 dtype = vae_dtype ,
@@ -132,9 +162,17 @@ def forward(
132162 # FastVideo VAE returns tensor directly
133163 image = decode_output
134164
165+ print (f"[FASTVIDEO VAE DEBUG] After decode - image sum: { image .float ().sum ().item ():.6f} , shape: { image .shape } " )
166+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
167+ f .write (f"[FASTVIDEO VAE DEBUG] After decode - image sum: { image .float ().sum ().item ():.6f} , shape: { image .shape } \n " )
168+
135169 # Normalize image to [0, 1] range
136170 image = (image / 2 + 0.5 ).clamp (0 , 1 )
137171
172+ print (f"[FASTVIDEO VAE DEBUG] After normalization - image sum: { image .float ().sum ().item ():.6f} " )
173+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
174+ f .write (f"[FASTVIDEO VAE DEBUG] After normalization - image sum: { image .float ().sum ().item ():.6f} \n " )
175+
138176 # Convert to CPU float32 for compatibility
139177 image = image .cpu ().float ()
140178
0 commit comments