@@ -306,6 +306,18 @@ def forward(
306306 if self .vae is not None :
307307 # Move VAE to correct device before encoding
308308 self .vae = self .vae .to (device )
309+
310+ # Log VAE info and input video stats
311+ print (f"[FASTVIDEO VAE DEBUG] VAE model: { type (self .vae ).__name__ } " )
312+ print (f"[FASTVIDEO VAE DEBUG] VAE config z_dim: { self .vae .config .z_dim } " )
313+ print (f"[FASTVIDEO VAE DEBUG] Input video shape: { video .shape } , dtype: { video .dtype } , device: { video .device } " )
314+ print (f"[FASTVIDEO VAE DEBUG] Input video sum: { video .float ().sum ().item ():.6f} " )
315+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
316+ f .write (f"FastVideo VAE: model_type = { type (self .vae ).__name__ } \n " )
317+ f .write (f"FastVideo VAE: z_dim = { self .vae .config .z_dim } \n " )
318+ f .write (f"FastVideo VAE: input_video_shape = { video .shape } \n " )
319+ f .write (f"FastVideo VAE: input_video_sum = { video .float ().sum ().item ():.6f} \n " )
320+
309321 if isinstance (generator , list ):
310322 init_latents = []
311323 for i in range (batch_size ):
@@ -361,9 +373,17 @@ def forward(
361373 latents_std = torch .tensor (self .vae .config .latents_std ).view (1 , self .vae .config .z_dim , 1 , 1 , 1 ).to (device , dtype )
362374 print (f"[FASTVIDEO CONDITIONING DEBUG] latents_mean = { self .vae .config .latents_mean } , latents_std = { self .vae .config .latents_std } " )
363375 print (f"[FASTVIDEO CONDITIONING DEBUG] scheduler.sigma_data = { self .scheduler .sigma_data } " )
376+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
377+ f .write (f"FastVideo Conditioning: scheduler.sigma_data = { self .scheduler .sigma_data } \n " )
378+ f .write (f"FastVideo Conditioning: latents_mean = { self .vae .config .latents_mean } \n " )
379+ f .write (f"FastVideo Conditioning: latents_std = { self .vae .config .latents_std } \n " )
364380 print (f"[FASTVIDEO CONDITIONING DEBUG] Before normalization sum = { init_latents .float ().sum ().item ()} " )
381+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
382+ f .write (f"FastVideo Conditioning: before_normalization_sum = { init_latents .float ().sum ().item ():.6f} \n " )
365383 init_latents = (init_latents - latents_mean ) / latents_std * self .scheduler .sigma_data
366384 print (f"[FASTVIDEO CONDITIONING DEBUG] After normalization sum = { init_latents .float ().sum ().item ()} " )
385+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
386+ f .write (f"FastVideo Conditioning: after_normalization_sum = { init_latents .float ().sum ().item ():.6f} \n " )
367387
368388 conditioning_latents = init_latents
369389 print (f"[FASTVIDEO CONDITIONING DEBUG] Final conditioning_latents sum = { conditioning_latents .float ().sum ().item ()} " )
@@ -441,6 +461,38 @@ def forward(
441461 if conditioning_latents is not None :
442462 logger .info (f"CosmosLatentPreparationStage - conditioning_latents shape: { conditioning_latents .shape } " )
443463
464+ # Log tensor sums to fastvideo_hidden_states.log
465+ sum_value = latents .float ().sum ().item ()
466+ print (f"FastVideo LatentPreparation: latents sum = { sum_value :.6f} " )
467+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
468+ f .write (f"FastVideo LatentPreparation: latents sum = { sum_value :.6f} \n " )
469+
470+ if conditioning_latents is not None :
471+ sum_value = conditioning_latents .float ().sum ().item ()
472+ print (f"FastVideo LatentPreparation: conditioning_latents sum = { sum_value :.6f} " )
473+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
474+ f .write (f"FastVideo LatentPreparation: conditioning_latents sum = { sum_value :.6f} \n " )
475+
476+ sum_value = cond_indicator .float ().sum ().item ()
477+ print (f"FastVideo LatentPreparation: cond_indicator sum = { sum_value :.6f} " )
478+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
479+ f .write (f"FastVideo LatentPreparation: cond_indicator sum = { sum_value :.6f} \n " )
480+
481+ sum_value = uncond_indicator .float ().sum ().item ()
482+ print (f"FastVideo LatentPreparation: uncond_indicator sum = { sum_value :.6f} " )
483+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
484+ f .write (f"FastVideo LatentPreparation: uncond_indicator sum = { sum_value :.6f} \n " )
485+
486+ sum_value = cond_mask .float ().sum ().item ()
487+ print (f"FastVideo LatentPreparation: cond_mask sum = { sum_value :.6f} " )
488+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
489+ f .write (f"FastVideo LatentPreparation: cond_mask sum = { sum_value :.6f} \n " )
490+
491+ sum_value = uncond_mask .float ().sum ().item ()
492+ print (f"FastVideo LatentPreparation: uncond_mask sum = { sum_value :.6f} " )
493+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
494+ f .write (f"FastVideo LatentPreparation: uncond_mask sum = { sum_value :.6f} \n " )
495+
444496 return batch
445497
446498 def verify_input (self , batch : ForwardBatch ,
0 commit comments