Skip to content

Commit 6fa24b7

Browse files
committed
Add logging
1 parent adf49de commit 6fa24b7

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

fastvideo/models/dits/cosmos.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -688,6 +688,7 @@ def forward(self,
688688
condition_mask: torch.Tensor | None = None,
689689
padding_mask: torch.Tensor | None = None,
690690
**kwargs) -> torch.Tensor:
691+
print(f"[FASTVIDEO TRANSFORMER] Input hidden_states sum = {hidden_states.float().sum().item()}")
691692
forward_batch = get_forward_context().forward_batch
692693
enable_teacache = forward_batch is not None and forward_batch.enable_teacache
693694

@@ -796,4 +797,5 @@ def forward(self,
796797
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
797798
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
798799

800+
print(f"[FASTVIDEO TRANSFORMER] Output hidden_states sum = {hidden_states.float().sum().item()}")
799801
return hidden_states

fastvideo/pipelines/stages/denoising.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,11 @@ def forward(
690690
conditioning_latents = getattr(batch, 'conditioning_latents', None)
691691
unconditioning_latents = conditioning_latents # Same for cosmos
692692

693+
# Add sigma_conditioning logic like diffusers (line 694-695)
694+
# sigma_conditioning = 0.0001 # Default value from diffusers
695+
# sigma_conditioning_tensor = torch.tensor(sigma_conditioning, dtype=torch.float32, device=latents.device)
696+
# t_conditioning = sigma_conditioning_tensor / (sigma_conditioning_tensor + 1)
697+
693698
# Sampling loop
694699
with self.progress_bar(total=num_inference_steps) as progress_bar:
695700
for i, t in enumerate(timesteps):
@@ -709,20 +714,39 @@ def forward(
709714
logger.info(f"Step {i}: current_sigma={current_sigma:.6f}, current_t={current_t:.6f}")
710715
logger.info(f"Step {i}: c_in={c_in:.6f}, c_skip={c_skip:.6f}, c_out={c_out:.6f}")
711716

712-
# Prepare timestep tensor
717+
# Prepare timestep tensor like diffusers (lines 713-715)
713718
timestep = current_t.view(1, 1, 1, 1, 1).expand(
714719
latents.size(0), -1, latents.size(2), -1, -1
715-
)
720+
) # [B, 1, T, 1, 1]
716721

717722
with torch.autocast(device_type="cuda",
718723
dtype=target_dtype,
719724
enabled=autocast_enabled):
720725

721-
# Conditional forward pass
726+
# Conditional forward pass - match diffusers exactly (lines 717-721)
722727
cond_latent = latents * c_in
723-
# Add conditional frame handling like diffusers
728+
print(f"[FASTVIDEO DEBUG] Step {i}: After latents * c_in, cond_latent sum = {cond_latent.float().sum().item()}")
729+
730+
# CRITICAL: Apply conditioning frame injection like diffusers
731+
print(f"[FASTVIDEO DEBUG] Step {i}: Conditioning check - cond_indicator exists: {hasattr(batch, 'cond_indicator')}, is not None: {batch.cond_indicator is not None if hasattr(batch, 'cond_indicator') else 'N/A'}, conditioning_latents is not None: {conditioning_latents is not None}")
724732
if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
733+
print(f"[FASTVIDEO DEBUG] Step {i}: Before conditioning - cond_latent sum: {cond_latent.float().sum().item()}, conditioning_latents sum: {conditioning_latents.float().sum().item()}")
725734
cond_latent = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_latent
735+
print(f"[FASTVIDEO DEBUG] Step {i}: After conditioning - cond_latent sum: {cond_latent.float().sum().item()}")
736+
logger.info(f"Step {i}: Applied conditioning frame injection - cond_latent sum: {cond_latent.float().sum().item():.6f}")
737+
else:
738+
print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING conditioning frame injection!")
739+
logger.warning(f"Step {i}: Missing conditioning data - cond_indicator: {hasattr(batch, 'cond_indicator')}, conditioning_latents: {conditioning_latents is not None}")
740+
741+
# cond_latent = cond_latent.to(target_dtype)
742+
743+
# # Apply conditional timestep processing like diffusers (lines 720-721)
744+
# cond_timestep = timestep
745+
# if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
746+
# cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
747+
# cond_timestep = cond_timestep.to(target_dtype)
748+
# if i < 3:
749+
# logger.info(f"Step {i}: Applied conditional timestep - t_conditioning: {t_conditioning:.6f}, cond_timestep sum: {cond_timestep.float().sum().item():.6f}")
726750

727751
with set_forward_context(
728752
current_timestep=i,
@@ -750,6 +774,7 @@ def forward(
750774
logger.info(f" condition_mask shape: {condition_mask.shape if condition_mask is not None else None}")
751775
logger.info(f" padding_mask shape: {padding_mask.shape}")
752776

777+
print(f"[FASTVIDEO DENOISING] About to call transformer with hidden_states sum = {cond_latent.float().sum().item()}")
753778
noise_pred = self.transformer(
754779
hidden_states=cond_latent.to(target_dtype),
755780
timestep=timestep.to(target_dtype),
@@ -775,6 +800,7 @@ def forward(
775800

776801
# Classifier-free guidance
777802
if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
803+
# Unconditional pass - match diffusers logic (lines 755-759)
778804
uncond_latent = latents * c_in
779805

780806
with set_forward_context(
@@ -790,6 +816,7 @@ def forward(
790816
logger.info(f"Step {i}: Uncond transformer inputs:")
791817
logger.info(f" uncond_latent sum: {uncond_latent.float().sum().item():.6f}")
792818
logger.info(f" negative_prompt_embeds shape: {batch.negative_prompt_embeds[0].shape}")
819+
# sum: {uncond_timestep.float().sum().item():.6f}")
793820

794821
noise_pred_uncond = self.transformer(
795822
hidden_states=uncond_latent.to(target_dtype),

fastvideo/pipelines/stages/latent_preparation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,14 +353,20 @@ def forward(
353353
init_latents = init_latents_list
354354

355355
init_latents = torch.cat(init_latents, dim=0).to(dtype)
356+
print(f"[FASTVIDEO CONDITIONING DEBUG] Raw VAE latents sum = {init_latents.float().sum().item()}")
356357

357358
# Apply latent normalization like diffusers
358359
if hasattr(self.vae.config, 'latents_mean') and hasattr(self.vae.config, 'latents_std'):
359360
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
360361
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
362+
print(f"[FASTVIDEO CONDITIONING DEBUG] latents_mean = {self.vae.config.latents_mean}, latents_std = {self.vae.config.latents_std}")
363+
print(f"[FASTVIDEO CONDITIONING DEBUG] scheduler.sigma_data = {self.scheduler.sigma_data}")
364+
print(f"[FASTVIDEO CONDITIONING DEBUG] Before normalization sum = {init_latents.float().sum().item()}")
361365
init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.sigma_data
366+
print(f"[FASTVIDEO CONDITIONING DEBUG] After normalization sum = {init_latents.float().sum().item()}")
362367

363368
conditioning_latents = init_latents
369+
print(f"[FASTVIDEO CONDITIONING DEBUG] Final conditioning_latents sum = {conditioning_latents.float().sum().item()}")
364370

365371
# Offload VAE to CPU after encoding to save memory
366372
self.vae.to("cpu")

0 commit comments

Comments
 (0)