Commit 1edf638

Update

1 parent 8fafbad commit 1edf638

File tree: 5 files changed, +219 -69 lines

fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py

Lines changed: 2 additions & 1 deletion
@@ -90,7 +90,8 @@ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         self.add_stage(stage_name="latent_preparation_stage",
                        stage=CosmosLatentPreparationStage(
                            scheduler=self.get_module("scheduler"),
-                           transformer=self.get_module("transformer")))
+                           transformer=self.get_module("transformer"),
+                           vae=self.get_module("vae")))
 
         # Denoising loop - corresponds to main denoising loop in __call__
         # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:673-752
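For readers unfamiliar with the stage-wiring pattern, self.get_module(name) resolves a named pipeline component and hands it to the stage constructor. A minimal sketch of the receiving side, assuming a constructor shape that matches this call site (the diff only shows the caller, so the class body below is hypothetical):

    # Hypothetical sketch of the receiving constructor; only the call site
    # above appears in the actual diff.
    class CosmosLatentPreparationStage:
        def __init__(self, scheduler, transformer, vae=None):
            self.scheduler = scheduler
            self.transformer = transformer
            # New in this commit: the VAE handle, presumably so the stage can
            # encode conditioning frames into latent space itself.
            self.vae = vae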

fastvideo/pipelines/stages/denoising.py

Lines changed: 81 additions & 49 deletions
@@ -624,7 +624,7 @@ class CosmosDenoisingStage(DenoisingStage):
     """
     Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
 
-    This stage implements the diffusers-compatible Cosmos denoising process with velocity prediction,
+    This stage implements the diffusers-compatible Cosmos denoising process with noise prediction,
     classifier-free guidance, and conditional video generation support.
     Compatible with Hugging Face Cosmos models.
     """
@@ -671,15 +671,24 @@ def forward(
         guidance_scale = batch.guidance_scale
 
 
-        # Setup scheduler timesteps (let scheduler generate proper sigmas)
+        # Setup scheduler timesteps like Diffusers does
+        # Diffusers uses set_timesteps without custom sigmas, letting the scheduler generate them
         self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
         timesteps = self.scheduler.timesteps
 
-        # Initialize with maximum noise
-        latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.sigma_max
+        # Handle final sigmas like diffusers
+        if hasattr(self.scheduler.config, 'final_sigmas_type') and self.scheduler.config.final_sigmas_type == "sigma_min":
+            if len(self.scheduler.sigmas) > 1:
+                self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
 
-        # Prepare conditional frame handling (if needed)
-        # This would be implemented based on batch.conditioning_latents or similar
+        # Debug: Log sigma information
+        logger.info(f"CosmosDenoisingStage - Scheduler sigmas shape: {self.scheduler.sigmas.shape}")
+        logger.info(f"CosmosDenoisingStage - Sigma range: {self.scheduler.sigmas.min():.6f} to {self.scheduler.sigmas.max():.6f}")
+        logger.info(f"CosmosDenoisingStage - First few sigmas: {self.scheduler.sigmas[:5]}")
+
+        # Get conditioning setup from batch (prepared by CosmosLatentPreparationStage)
+        conditioning_latents = getattr(batch, 'conditioning_latents', None)
+        unconditioning_latents = conditioning_latents  # Same for cosmos
 
         # Sampling loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
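The sigma schedule itself comes entirely from the scheduler. A standalone sketch of inspecting it with diffusers' FlowMatchEulerDiscreteScheduler (the final_sigmas_type key mirrors the guard above; whether a given scheduler config defines it varies, hence the defensive lookup):

    from diffusers import FlowMatchEulerDiscreteScheduler

    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler.set_timesteps(35, device="cpu")   # no custom sigmas, as in the patch

    print(scheduler.sigmas.shape)               # one more entry than timesteps
    print(scheduler.sigmas.min(), scheduler.sigmas.max())

    # Same tail-sigma handling as the patch, guarded because not every
    # scheduler config defines final_sigmas_type.
    if getattr(scheduler.config, "final_sigmas_type", None) == "sigma_min" and len(scheduler.sigmas) > 1:
        scheduler.sigmas[-1] = scheduler.sigmas[-2]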
@@ -695,6 +704,11 @@ def forward(
                 c_skip = 1 - current_t
                 c_out = -current_t
 
+                # Debug: Log sigma and coefficients for first few steps
+                if i < 3:
+                    logger.info(f"Step {i}: current_sigma={current_sigma:.6f}, current_t={current_t:.6f}")
+                    logger.info(f"Step {i}: c_in={c_in:.6f}, c_skip={c_skip:.6f}, c_out={c_out:.6f}")
+
                 # Prepare timestep tensor
                 timestep = current_t.view(1, 1, 1, 1, 1).expand(
                     latents.size(0), -1, latents.size(2), -1, -1
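Only c_skip and c_out are visible in this hunk; current_t and c_in are computed just above it. In the diffusers Cosmos pipeline the step's sigma is mapped as current_t = sigma / (sigma + 1) and c_in = 1 - current_t, and assuming that same mapping here, the preconditioning is plain tensor arithmetic:

    import torch

    sigma = torch.tensor(10.0)                 # current_sigma for this step

    # Assumed mapping (matches the diffusers Cosmos pipeline; only c_skip
    # and c_out appear in the hunk above):
    current_t = sigma / (sigma + 1)
    c_in = 1 - current_t
    c_skip = 1 - current_t
    c_out = -current_t

    latents = torch.randn(1, 16, 2, 8, 8)      # (B, C, T, H, W) video latent
    model_out = torch.randn_like(latents)      # transformer output on c_in * latents

    # The cond_pred expression from the diff: a preconditioned x0 estimate.
    x0_pred = c_skip * latents + c_out * model_out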
@@ -706,8 +720,9 @@ def forward(
 
                 # Conditional forward pass
                 cond_latent = latents * c_in
-                # Add conditional frame handling here if needed:
-                # cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
+                # Add conditional frame handling like diffusers
+                if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
+                    cond_latent = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_latent
 
                 with set_forward_context(
                     current_timestep=i,
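The indicator blend splices clean conditioning frames into the noisy latent along the frame axis, as the loop now does for both the conditional and unconditional passes. A self-contained sketch with an assumed per-frame indicator shaped for broadcasting:

    import torch

    latents = torch.randn(1, 16, 4, 8, 8)              # (B, C, T, H, W)
    conditioning_latents = torch.randn_like(latents)   # clean frames from the VAE

    # Assumed layout: 1.0 on conditioning frames, 0.0 elsewhere, broadcastable
    # over channels and spatial dims.
    cond_indicator = torch.zeros(1, 1, 4, 1, 1)
    cond_indicator[:, :, 0] = 1.0                      # condition on the first frame

    # Same blend as the diff: clean latents where conditioned, noisy elsewhere.
    cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * latents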
@@ -726,7 +741,16 @@ def forward(
                         device=cond_latent.device, dtype=target_dtype)
 
 
-                cond_velocity = self.transformer(
+                # Debug transformer inputs for first few steps
+                if i < 3:
+                    logger.info(f"Step {i}: Transformer inputs:")
+                    logger.info(f"  cond_latent shape: {cond_latent.shape}, sum: {cond_latent.float().sum().item():.6f}")
+                    logger.info(f"  timestep shape: {timestep.shape}, values: {timestep.flatten()[:5]}")
+                    logger.info(f"  prompt_embeds shape: {batch.prompt_embeds[0].shape}")
+                    logger.info(f"  condition_mask shape: {condition_mask.shape if condition_mask is not None else None}")
+                    logger.info(f"  padding_mask shape: {padding_mask.shape}")
+
+                noise_pred = self.transformer(
                     hidden_states=cond_latent.to(target_dtype),
                     timestep=timestep.to(target_dtype),
                     encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
@@ -735,14 +759,14 @@ def forward(
                     padding_mask=padding_mask,
                     return_dict=False,
                 )[0]
-                sum_value = cond_velocity.float().sum().item()
-                logger.info(f"CosmosDenoisingStage: step {i}, cond_velocity sum = {sum_value:.6f}")
+                sum_value = noise_pred.float().sum().item()
+                logger.info(f"CosmosDenoisingStage: step {i}, noise_pred sum = {sum_value:.6f}")
                 # Write to output file
                 with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
-                    f.write(f"CosmosDenoisingStage: step {i}, cond_velocity sum = {sum_value:.6f}\n")
+                    f.write(f"CosmosDenoisingStage: step {i}, noise_pred sum = {sum_value:.6f}\n")
 
-                # Apply preconditioning and conditional masking
-                cond_pred = (c_skip * latents + c_out * cond_velocity.float()).to(target_dtype)
+                # Apply preconditioning exactly like diffusers
+                cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
 
                 # Apply conditional indicator masking (from CosmosLatentPreparationStage)
                 if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
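The recurring sum-and-log blocks are a parity-debugging technique: record a cheap scalar fingerprint of each intermediate tensor so this implementation can be diffed line by line against a diffusers run. The diff inlines the pattern everywhere; a hypothetical helper that factors it out would look like:

    import torch

    LOG_PATH = "/workspace/FastVideo/fastvideo_hidden_states.log"

    def log_sum(name: str, step: int, tensor: torch.Tensor, path: str = LOG_PATH) -> None:
        """Append a scalar fingerprint of `tensor` for cross-implementation diffing."""
        line = f"CosmosDenoisingStage: step {step}, {name} sum = {tensor.float().sum().item():.6f}"
        with open(path, "a") as f:
            f.write(line + "\n")

A float32 sum is sensitive enough to catch most divergences while costing almost nothing; comparing the two log files pinpoints the first step at which the implementations drift.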
@@ -761,7 +785,13 @@ def forward(
                     # Use uncond_mask for unconditional pass if available
                     uncond_condition_mask = batch.uncond_mask.to(target_dtype) if hasattr(batch, 'uncond_mask') and batch.uncond_mask is not None else condition_mask
 
-                    uncond_velocity = self.transformer(
+                    # Debug unconditional transformer inputs for first few steps
+                    if i < 3:
+                        logger.info(f"Step {i}: Uncond transformer inputs:")
+                        logger.info(f"  uncond_latent sum: {uncond_latent.float().sum().item():.6f}")
+                        logger.info(f"  negative_prompt_embeds shape: {batch.negative_prompt_embeds[0].shape}")
+
+                    noise_pred_uncond = self.transformer(
                         hidden_states=uncond_latent.to(target_dtype),
                         timestep=timestep.to(target_dtype),
                         encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
@@ -770,60 +800,62 @@ def forward(
                         padding_mask=padding_mask,
                         return_dict=False,
                     )[0]
-                    sum_value = uncond_velocity.float().sum().item()
-                    logger.info(f"CosmosDenoisingStage: step {i}, uncond_velocity sum = {sum_value:.6f}")
+                    sum_value = noise_pred_uncond.float().sum().item()
+                    logger.info(f"CosmosDenoisingStage: step {i}, noise_pred_uncond sum = {sum_value:.6f}")
                     # Write to output file
                     with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
-                        f.write(f"CosmosDenoisingStage: step {i}, uncond_velocity sum = {sum_value:.6f}\n")
+                        f.write(f"CosmosDenoisingStage: step {i}, noise_pred_uncond sum = {sum_value:.6f}\n")
 
-                    uncond_pred = (c_skip * latents + c_out * uncond_velocity.float()).to(target_dtype)
+                    uncond_pred = (c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype)
 
-                    # Apply conditional indicator masking for unconditional prediction
-                    if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
-                        unconditioning_latents = conditioning_latents  # Same as conditioning for cosmos
+                    # Apply conditional indicator masking for unconditional prediction like diffusers
+                    if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None and unconditioning_latents is not None:
                         uncond_pred = batch.uncond_indicator * unconditioning_latents + (1 - batch.uncond_indicator) * uncond_pred
 
-                    # Apply guidance
-                    noise_pred = cond_pred + guidance_scale * (cond_pred - uncond_pred)
-                    sum_value = noise_pred.float().sum().item()
+                    # Apply guidance exactly like diffusers
+                    guidance_diff = cond_pred - uncond_pred
+                    final_pred = cond_pred + guidance_scale * guidance_diff
+
+                    # Debug guidance computation
+                    if i < 3:  # Log first few steps
+                        logger.info(f"Step {i}: Guidance debug:")
+                        logger.info(f"  cond_pred sum = {cond_pred.float().sum().item():.6f}")
+                        logger.info(f"  uncond_pred sum = {uncond_pred.float().sum().item():.6f}")
+                        logger.info(f"  guidance_diff sum = {guidance_diff.float().sum().item():.6f}")
+                        logger.info(f"  guidance_scale = {guidance_scale}")
+                        logger.info(f"  final_pred sum = {final_pred.float().sum().item():.6f}")
+
+                    sum_value = final_pred.float().sum().item()
                     logger.info(f"CosmosDenoisingStage: step {i}, final noise_pred sum = {sum_value:.6f}")
                     # Write to output file
                     with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
                         f.write(f"CosmosDenoisingStage: step {i}, final noise_pred sum = {sum_value:.6f}\n")
                 else:
-                    noise_pred = cond_pred
-
-                # Debug: Check for NaN values before conversion
-                logger.info(f"Step {i}: Before conversion - latents NaN count: {torch.isnan(latents).sum()}")
-                logger.info(f"Step {i}: Before conversion - noise_pred NaN count: {torch.isnan(noise_pred).sum()}")
-                logger.info(f"Step {i}: current_sigma: {current_sigma}")
+                    final_pred = cond_pred
+                    if i < 3:
+                        logger.info(f"Step {i}: No CFG, using cond_pred directly: {final_pred.float().sum().item():.6f}")
 
-                # Convert from velocity prediction to noise (Cosmos-specific conversion)
-                # Add small epsilon to prevent division by zero
-                current_sigma_safe = torch.clamp(current_sigma, min=1e-8)
-                noise_pred = (latents - noise_pred) / current_sigma_safe
+                # Convert to noise for scheduler step exactly like diffusers
+                # Add safety check to prevent division by zero
+                if current_sigma > 1e-8:
+                    noise_for_scheduler = (latents - final_pred) / current_sigma
+                else:
+                    logger.warning(f"Step {i}: current_sigma too small ({current_sigma}), using final_pred directly")
+                    noise_for_scheduler = final_pred
 
-                # Debug: Check for NaN values after conversion
-                logger.info(f"Step {i}: After conversion - noise_pred NaN count: {torch.isnan(noise_pred).sum()}")
+                # Debug: Check for NaN values before scheduler step
+                if torch.isnan(noise_for_scheduler).sum() > 0:
+                    logger.error(f"Step {i}: NaN detected in noise_for_scheduler, sum: {noise_for_scheduler.float().sum().item()}")
+                    logger.error(f"Step {i}: latents sum: {latents.float().sum().item()}, final_pred sum: {final_pred.float().sum().item()}, current_sigma: {current_sigma}")
 
-                # Standard scheduler step
-                latents_before = latents.clone()
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                # Standard scheduler step like diffusers
+                latents = self.scheduler.step(noise_for_scheduler, t, latents, return_dict=False)[0]
                 sum_value = latents.float().sum().item()
                 logger.info(f"CosmosDenoisingStage: step {i}, updated latents sum = {sum_value:.6f}")
                 # Write to output file
                 with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
                     f.write(f"CosmosDenoisingStage: step {i}, updated latents sum = {sum_value:.6f}\n")
 
-                # Debug: Check for NaN values after scheduler step
-                logger.info(f"Step {i}: After scheduler - latents NaN count: {torch.isnan(latents).sum()}")
-                logger.info(f"Step {i}: latents shape change: {latents_before.shape} -> {latents.shape}")
-
-                # Break if NaN detected
-                if torch.isnan(latents).sum() > 0:
-                    logger.error(f"NaN detected at step {i}, breaking denoising loop")
-                    break
-
                 progress_bar.update()
 
         # Update batch with final latents
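Condensing the tail of the loop: guidance extrapolates between the two x0 estimates, the result is converted back to a noise estimate, and the scheduler advances the latents. A self-contained sketch of those three moves (the function boundary is illustrative; the diff inlines this in the loop):

    import torch

    def denoise_step(latents, cond_pred, uncond_pred, guidance_scale,
                     current_sigma, scheduler, t):
        # Guidance on x0 estimates, as in the diff (note it extrapolates from
        # cond_pred, not uncond_pred).
        final_pred = cond_pred + guidance_scale * (cond_pred - uncond_pred)

        # x0 estimate -> noise estimate, with the same division guard.
        if current_sigma > 1e-8:
            noise_for_scheduler = (latents - final_pred) / current_sigma
        else:
            noise_for_scheduler = final_pred

        # One Euler step of the flow-match scheduler.
        return scheduler.step(noise_for_scheduler, t, latents, return_dict=False)[0]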
