Skip to content

Commit 56be69d

Browse files
committed
Uncond pred sum
1 parent 2c78d9b commit 56be69d

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

fastvideo/pipelines/stages/denoising.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -741,13 +741,21 @@ def forward(
741741

742742
# CRITICAL: Apply conditioning frame injection like diffusers
743743
print(f"[FASTVIDEO DEBUG] Step {i}: Conditioning check - cond_indicator exists: {hasattr(batch, 'cond_indicator')}, is not None: {batch.cond_indicator is not None if hasattr(batch, 'cond_indicator') else 'N/A'}, conditioning_latents is not None: {conditioning_latents is not None}")
744+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
745+
f.write(f"[FASTVIDEO DEBUG] Step {i}: Conditioning check - cond_indicator exists: {hasattr(batch, 'cond_indicator')}, is not None: {batch.cond_indicator is not None if hasattr(batch, 'cond_indicator') else 'N/A'}, conditioning_latents is not None: {conditioning_latents is not None}\n")
744746
if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
745747
print(f"[FASTVIDEO DEBUG] Step {i}: Before conditioning - cond_latent sum: {cond_latent.float().sum().item()}, conditioning_latents sum: {conditioning_latents.float().sum().item()}")
748+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
749+
f.write(f"[FASTVIDEO DEBUG] Step {i}: Before conditioning - cond_latent sum: {cond_latent.float().sum().item()}, conditioning_latents sum: {conditioning_latents.float().sum().item()}\n")
746750
cond_latent = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_latent
747751
print(f"[FASTVIDEO DEBUG] Step {i}: After conditioning - cond_latent sum: {cond_latent.float().sum().item()}")
752+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
753+
f.write(f"[FASTVIDEO DEBUG] Step {i}: After conditioning - cond_latent sum: {cond_latent.float().sum().item()}\n")
748754
logger.info(f"Step {i}: Applied conditioning frame injection - cond_latent sum: {cond_latent.float().sum().item():.6f}")
749755
else:
750756
print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING conditioning frame injection!")
757+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
758+
f.write(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING conditioning frame injection!\n")
751759
logger.warning(f"Step {i}: Missing conditioning data - cond_indicator: {hasattr(batch, 'cond_indicator')}, conditioning_latents: {conditioning_latents is not None}")
752760

753761
# Convert cond_latent to target dtype BEFORE debug logging to match Diffusers
@@ -833,14 +841,40 @@ def forward(
833841
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
834842
f.write(f"CosmosDenoisingStage: step {i}, noise_pred sum = {sum_value:.6f}\n")
835843

836-
# Apply preconditioning exactly like diffusers
844+
print(f"[FASTVIDEO DEBUG] Step {i}: Preconditioning - c_skip={c_skip:.6f}, c_out={c_out:.6f}, latents_sum={latents.float().sum().item():.6f}")
845+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
846+
f.write(f"[FASTVIDEO DEBUG] Step {i}: Preconditioning - c_skip={c_skip:.6f}, c_out={c_out:.6f}, latents_sum={latents.float().sum().item():.6f}\n")
837847
cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
848+
849+
if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
850+
cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred
851+
print(f"[FASTVIDEO DEBUG] Step {i}: Applied post-preconditioning conditioning - cond_pred sum: {cond_pred.float().sum().item()}")
852+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
853+
f.write(f"[FASTVIDEO DEBUG] Step {i}: Applied post-preconditioning conditioning - cond_pred sum: {cond_pred.float().sum().item()}\n")
854+
else:
855+
print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING post-preconditioning conditioning")
856+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
857+
f.write(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING post-preconditioning conditioning\n")
838858

839859
# NOTE: Conditioning frame injection is applied to cond_latent BEFORE transformer call (line 746), not after
840860
# Classifier-free guidance
841861
if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
842862
# Unconditional pass - match diffusers logic (lines 755-759)
843863
uncond_latent = latents * c_in
864+
865+
print(f"[FASTVIDEO DEBUG] Step {i}: Before unconditional conditioning - uncond_latent sum: {uncond_latent.float().sum().item()}")
866+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
867+
f.write(f"[FASTVIDEO DEBUG] Step {i}: Before unconditional conditioning - uncond_latent sum: {uncond_latent.float().sum().item()}\n")
868+
869+
if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None and unconditioning_latents is not None:
870+
uncond_latent = batch.uncond_indicator * unconditioning_latents + (1 - batch.uncond_indicator) * uncond_latent
871+
print(f"[FASTVIDEO DEBUG] Step {i}: Applied unconditional conditioning - uncond_latent sum: {uncond_latent.float().sum().item()}")
872+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
873+
f.write(f"[FASTVIDEO DEBUG] Step {i}: Applied unconditional conditioning - uncond_latent sum: {uncond_latent.float().sum().item()}\n")
874+
else:
875+
print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING unconditional conditioning - uncond_indicator: {hasattr(batch, 'uncond_indicator')}, unconditioning_latents: {unconditioning_latents is not None}")
876+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
877+
f.write(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING unconditional conditioning - uncond_indicator: {hasattr(batch, 'uncond_indicator')}, unconditioning_latents: {unconditioning_latents is not None}\n")
844878

845879
with set_forward_context(
846880
current_timestep=i,
@@ -886,8 +920,12 @@ def forward(
886920
if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None and unconditioning_latents is not None:
887921
uncond_pred = batch.uncond_indicator * unconditioning_latents + (1 - batch.uncond_indicator) * uncond_pred
888922

889-
# Apply guidance exactly like diffusers
890923
guidance_diff = cond_pred - uncond_pred
924+
print(f"[FASTVIDEO DEBUG] Step {i}: CFG calculation - guidance_scale = {guidance_scale}")
925+
print(f"[FASTVIDEO DEBUG] Step {i}: CFG values - cond_pred: {cond_pred.float().sum().item():.6f}, uncond_pred: {uncond_pred.float().sum().item():.6f}, guidance_diff: {guidance_diff.float().sum().item():.6f}")
926+
with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
927+
f.write(f"[FASTVIDEO DEBUG] Step {i}: CFG calculation - guidance_scale = {guidance_scale}\n")
928+
f.write(f"[FASTVIDEO DEBUG] Step {i}: CFG values - cond_pred: {cond_pred.float().sum().item():.6f}, uncond_pred: {uncond_pred.float().sum().item():.6f}, guidance_diff: {guidance_diff.float().sum().item():.6f}\n")
891929
final_pred = cond_pred + guidance_scale * guidance_diff
892930

893931
# Debug guidance computation

0 commit comments

Comments
 (0)