Commit 6599352
Add additional logging
1 parent 6fa24b7 commit 6599352

2 files changed: +68 −4 lines changed


fastvideo/pipelines/stages/denoising.py

Lines changed: 16 additions & 4 deletions
@@ -660,19 +660,31 @@ def forward(
         pipeline.add_module("transformer", self.transformer)
         fastvideo_args.model_loaded["transformer"] = True

-        # Setup precision and autocast settings
-        target_dtype = torch.bfloat16
+        # Setup precision to match diffusers exactly
+        # Diffusers uses transformer.dtype (bfloat16) and converts inputs before transformer calls
+        # For FSDP wrapped models, we need to access the underlying module
+        if hasattr(self.transformer, 'module'):
+            transformer_dtype = next(self.transformer.module.parameters()).dtype
+        else:
+            transformer_dtype = next(self.transformer.parameters()).dtype
+        target_dtype = transformer_dtype
         autocast_enabled = (target_dtype != torch.float32
                             ) and not fastvideo_args.disable_autocast

         # Get latents and setup
         latents = batch.latents
         num_inference_steps = batch.num_inference_steps
         guidance_scale = batch.guidance_scale
+
+        sum_value = latents.float().sum().item()
+        # Write to output file
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"Denoising init: latents sum = {sum_value:.6f}, shape = {latents.shape}\n")


-        # Setup scheduler timesteps like Diffusers does
-        # Diffusers uses set_timesteps without custom sigmas, letting the scheduler generate them
+        # Setup scheduler timesteps - use default scheduler sigma generation
+        # The torch.linspace(0, 1, num_inference_steps) approach was incorrect for FlowMatchEulerDiscreteScheduler
+        # Let the scheduler generate its own sigmas using the configured sigma_max, sigma_min, etc.
         self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
         timesteps = self.scheduler.timesteps
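
For context, a minimal sketch of the dtype-resolution pattern used in the hunk above, assuming the FSDP wrapper exposes the original model via a `.module` attribute (as PyTorch's FullyShardedDataParallel does); `resolve_param_dtype` is a hypothetical helper name, not part of FastVideo:

    import torch
    import torch.nn as nn

    def resolve_param_dtype(model: nn.Module) -> torch.dtype:
        # FSDP and similar wrappers keep the original model under `.module`;
        # fall back to the model itself when it is not wrapped.
        unwrapped = model.module if hasattr(model, "module") else model
        return next(unwrapped.parameters()).dtype

    # Usage mirroring the hunk: enable autocast only for non-float32 parameters.
    # target_dtype = resolve_param_dtype(self.transformer)
    # autocast_enabled = target_dtype != torch.float32 and not fastvideo_args.disable_autocast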

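The scheduler comment above refers to diffusers' FlowMatchEulerDiscreteScheduler. A hedged usage sketch of letting the scheduler generate its own sigmas via set_timesteps (default constructor shown purely for illustration; the pipeline loads its own scheduler config):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler

    # Default config for illustration only.
    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(num_inference_steps=30, device=torch.device("cpu"))

    # The scheduler now owns both the timestep schedule and the matching sigmas,
    # so there is no need to hand it a torch.linspace(0, 1, N) sigma grid.
    timesteps = scheduler.timesteps
    sigmas = scheduler.sigmas
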
fastvideo/pipelines/stages/latent_preparation.py

Lines changed: 52 additions & 0 deletions
@@ -306,6 +306,18 @@ def forward(
         if self.vae is not None:
             # Move VAE to correct device before encoding
             self.vae = self.vae.to(device)
+
+            # Log VAE info and input video stats
+            print(f"[FASTVIDEO VAE DEBUG] VAE model: {type(self.vae).__name__}")
+            print(f"[FASTVIDEO VAE DEBUG] VAE config z_dim: {self.vae.config.z_dim}")
+            print(f"[FASTVIDEO VAE DEBUG] Input video shape: {video.shape}, dtype: {video.dtype}, device: {video.device}")
+            print(f"[FASTVIDEO VAE DEBUG] Input video sum: {video.float().sum().item():.6f}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"FastVideo VAE: model_type = {type(self.vae).__name__}\n")
+                f.write(f"FastVideo VAE: z_dim = {self.vae.config.z_dim}\n")
+                f.write(f"FastVideo VAE: input_video_shape = {video.shape}\n")
+                f.write(f"FastVideo VAE: input_video_sum = {video.float().sum().item():.6f}\n")
+
             if isinstance(generator, list):
                 init_latents = []
                 for i in range(batch_size):
@@ -361,9 +373,17 @@ def forward(
             latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
             print(f"[FASTVIDEO CONDITIONING DEBUG] latents_mean = {self.vae.config.latents_mean}, latents_std = {self.vae.config.latents_std}")
             print(f"[FASTVIDEO CONDITIONING DEBUG] scheduler.sigma_data = {self.scheduler.sigma_data}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"FastVideo Conditioning: scheduler.sigma_data = {self.scheduler.sigma_data}\n")
+                f.write(f"FastVideo Conditioning: latents_mean = {self.vae.config.latents_mean}\n")
+                f.write(f"FastVideo Conditioning: latents_std = {self.vae.config.latents_std}\n")
             print(f"[FASTVIDEO CONDITIONING DEBUG] Before normalization sum = {init_latents.float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"FastVideo Conditioning: before_normalization_sum = {init_latents.float().sum().item():.6f}\n")
             init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.sigma_data
             print(f"[FASTVIDEO CONDITIONING DEBUG] After normalization sum = {init_latents.float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"FastVideo Conditioning: after_normalization_sum = {init_latents.float().sum().item():.6f}\n")

             conditioning_latents = init_latents
             print(f"[FASTVIDEO CONDITIONING DEBUG] Final conditioning_latents sum = {conditioning_latents.float().sum().item()}")
@@ -441,6 +461,38 @@ def forward(
         if conditioning_latents is not None:
             logger.info(f"CosmosLatentPreparationStage - conditioning_latents shape: {conditioning_latents.shape}")

+        # Log tensor sums to fastvideo_hidden_states.log
+        sum_value = latents.float().sum().item()
+        print(f"FastVideo LatentPreparation: latents sum = {sum_value:.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"FastVideo LatentPreparation: latents sum = {sum_value:.6f}\n")
+
+        if conditioning_latents is not None:
+            sum_value = conditioning_latents.float().sum().item()
+            print(f"FastVideo LatentPreparation: conditioning_latents sum = {sum_value:.6f}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"FastVideo LatentPreparation: conditioning_latents sum = {sum_value:.6f}\n")
+
+        sum_value = cond_indicator.float().sum().item()
+        print(f"FastVideo LatentPreparation: cond_indicator sum = {sum_value:.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"FastVideo LatentPreparation: cond_indicator sum = {sum_value:.6f}\n")
+
+        sum_value = uncond_indicator.float().sum().item()
+        print(f"FastVideo LatentPreparation: uncond_indicator sum = {sum_value:.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"FastVideo LatentPreparation: uncond_indicator sum = {sum_value:.6f}\n")
+
+        sum_value = cond_mask.float().sum().item()
+        print(f"FastVideo LatentPreparation: cond_mask sum = {sum_value:.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"FastVideo LatentPreparation: cond_mask sum = {sum_value:.6f}\n")
+
+        sum_value = uncond_mask.float().sum().item()
+        print(f"FastVideo LatentPreparation: uncond_mask sum = {sum_value:.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"FastVideo LatentPreparation: uncond_mask sum = {sum_value:.6f}\n")
+
         return batch

     def verify_input(self, batch: ForwardBatch,
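
Since the sum-and-append pattern in the block above repeats for every tensor, here is a hedged sketch of a helper that could consolidate it; `log_tensor_sum` is a hypothetical name, and only the log path and message format are taken from the diff:

    from typing import Optional

    import torch

    LOG_PATH = "/workspace/FastVideo/fastvideo_hidden_states.log"

    def log_tensor_sum(name: str, tensor: Optional[torch.Tensor], stage: str = "LatentPreparation") -> None:
        # Optional tensors (e.g. conditioning_latents) may be None; skip them.
        if tensor is None:
            return
        sum_value = tensor.float().sum().item()
        line = f"FastVideo {stage}: {name} sum = {sum_value:.6f}"
        print(line)
        with open(LOG_PATH, "a") as f:
            f.write(line + "\n")

    # Usage mirroring the block above:
    # for name, t in [("latents", latents), ("conditioning_latents", conditioning_latents),
    #                 ("cond_indicator", cond_indicator), ("uncond_indicator", uncond_indicator),
    #                 ("cond_mask", cond_mask), ("uncond_mask", uncond_mask)]:
    #     log_tensor_sum(name, t)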
