@@ -624,7 +624,7 @@ class CosmosDenoisingStage(DenoisingStage):
624624 """
625625 Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
626626
627- This stage implements the diffusers-compatible Cosmos denoising process with velocity prediction,
627+ This stage implements the diffusers-compatible Cosmos denoising process with noise prediction,
628628 classifier-free guidance, and conditional video generation support.
629629 Compatible with Hugging Face Cosmos models.
630630 """
@@ -671,15 +671,24 @@ def forward(
         guidance_scale = batch.guidance_scale


-        # Setup scheduler timesteps (let scheduler generate proper sigmas)
+        # Setup scheduler timesteps like diffusers does
+        # Diffusers uses set_timesteps without custom sigmas, letting the scheduler generate them
         self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
         timesteps = self.scheduler.timesteps
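+        # NOTE (assumption): with no explicit sigmas passed in, FlowMatchEulerDiscreteScheduler
+        # derives its own sigma schedule, including any shift configured on the scheduler.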

-        # Initialize with maximum noise
-        latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.sigma_max
+        # Handle final sigmas like diffusers
+        if hasattr(self.scheduler.config, 'final_sigmas_type') and self.scheduler.config.final_sigmas_type == "sigma_min":
+            if len(self.scheduler.sigmas) > 1:
+                self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
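+                # Re-pointing the last sigma at the penultimate value approximates
+                # final_sigmas_type="sigma_min", so the final step ends at sigma_min
+                # rather than 0 (assumed from the diffusers scheduler's behavior).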

-        # Prepare conditional frame handling (if needed)
-        # This would be implemented based on batch.conditioning_latents or similar
+        # Debug: Log sigma information
+        logger.info(f"CosmosDenoisingStage - Scheduler sigmas shape: {self.scheduler.sigmas.shape}")
+        logger.info(f"CosmosDenoisingStage - Sigma range: {self.scheduler.sigmas.min():.6f} to {self.scheduler.sigmas.max():.6f}")
+        logger.info(f"CosmosDenoisingStage - First few sigmas: {self.scheduler.sigmas[:5]}")
+
+        # Get conditioning setup from batch (prepared by CosmosLatentPreparationStage)
+        conditioning_latents = getattr(batch, 'conditioning_latents', None)
+        unconditioning_latents = conditioning_latents  # Same for Cosmos

         # Sampling loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -695,6 +704,11 @@ def forward(
                 c_skip = 1 - current_t
                 c_out = -current_t
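+                # Assumed preconditioning convention (from the diffusers Cosmos pipeline):
+                # with current_t = sigma / (sigma + 1), the clean-latent estimate is later
+                # recovered as c_skip * latents + c_out * model_output.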
697706
+                # Debug: Log sigma and coefficients for first few steps
+                if i < 3:
+                    logger.info(f"Step {i}: current_sigma={current_sigma:.6f}, current_t={current_t:.6f}")
+                    logger.info(f"Step {i}: c_in={c_in:.6f}, c_skip={c_skip:.6f}, c_out={c_out:.6f}")
+
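+                # The transformer takes one timestep value per frame, broadcast over batch
+                # and spatial dims, i.e. shape (B, 1, T, 1, 1) (shape inferred from the expand below).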
                 # Prepare timestep tensor
                 timestep = current_t.view(1, 1, 1, 1, 1).expand(
                     latents.size(0), -1, latents.size(2), -1, -1
@@ -706,8 +720,9 @@ def forward(

                 # Conditional forward pass
                 cond_latent = latents * c_in
-                # Add conditional frame handling here if needed:
-                # cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
+                # Add conditional frame handling like diffusers
+                if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
+                    cond_latent = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_latent
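+                # Conditioning frames are pinned to their clean latents here; only frames
+                # whose indicator is 0 keep the noisy input.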

                 with set_forward_context(
                     current_timestep=i,
@@ -726,7 +741,16 @@ def forward(
                         device=cond_latent.device, dtype=target_dtype)


-                    cond_velocity = self.transformer(
+                    # Debug transformer inputs for first few steps
+                    if i < 3:
+                        logger.info(f"Step {i}: Transformer inputs:")
+                        logger.info(f"  cond_latent shape: {cond_latent.shape}, sum: {cond_latent.float().sum().item():.6f}")
+                        logger.info(f"  timestep shape: {timestep.shape}, values: {timestep.flatten()[:5]}")
+                        logger.info(f"  prompt_embeds shape: {batch.prompt_embeds[0].shape}")
+                        logger.info(f"  condition_mask shape: {condition_mask.shape if condition_mask is not None else None}")
+                        logger.info(f"  padding_mask shape: {padding_mask.shape}")
+
+                    noise_pred = self.transformer(
                         hidden_states=cond_latent.to(target_dtype),
                         timestep=timestep.to(target_dtype),
                         encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
@@ -735,14 +759,14 @@ def forward(
                         padding_mask=padding_mask,
                         return_dict=False,
                     )[0]
-                    sum_value = cond_velocity.float().sum().item()
-                    logger.info(f"CosmosDenoisingStage: step {i}, cond_velocity sum = {sum_value:.6f}")
+                    sum_value = noise_pred.float().sum().item()
+                    logger.info(f"CosmosDenoisingStage: step {i}, noise_pred sum = {sum_value:.6f}")
                     # Write to output file
                     with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
-                        f.write(f"CosmosDenoisingStage: step {i}, cond_velocity sum = {sum_value:.6f}\n")
+                        f.write(f"CosmosDenoisingStage: step {i}, noise_pred sum = {sum_value:.6f}\n")

-                    # Apply preconditioning and conditional masking
-                    cond_pred = (c_skip * latents + c_out * cond_velocity.float()).to(target_dtype)
+                    # Apply preconditioning exactly like diffusers
+                    cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
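+                    # cond_pred is an estimate of the clean latents (x0), not raw noise;
+                    # the division by sigma further down converts it back for the scheduler.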

                     # Apply conditional indicator masking (from CosmosLatentPreparationStage)
                     if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
@@ -761,7 +785,13 @@ def forward(
                         # Use uncond_mask for unconditional pass if available
                         uncond_condition_mask = batch.uncond_mask.to(target_dtype) if hasattr(batch, 'uncond_mask') and batch.uncond_mask is not None else condition_mask

-                        uncond_velocity = self.transformer(
+                        # Debug unconditional transformer inputs for first few steps
+                        if i < 3:
+                            logger.info(f"Step {i}: Uncond transformer inputs:")
+                            logger.info(f"  uncond_latent sum: {uncond_latent.float().sum().item():.6f}")
+                            logger.info(f"  negative_prompt_embeds shape: {batch.negative_prompt_embeds[0].shape}")
+
+                        noise_pred_uncond = self.transformer(
                             hidden_states=uncond_latent.to(target_dtype),
                             timestep=timestep.to(target_dtype),
                             encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
@@ -770,60 +800,62 @@ def forward(
                             padding_mask=padding_mask,
                             return_dict=False,
                         )[0]
-                        sum_value = uncond_velocity.float().sum().item()
-                        logger.info(f"CosmosDenoisingStage: step {i}, uncond_velocity sum = {sum_value:.6f}")
+                        sum_value = noise_pred_uncond.float().sum().item()
+                        logger.info(f"CosmosDenoisingStage: step {i}, noise_pred_uncond sum = {sum_value:.6f}")
                         # Write to output file
                         with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
-                            f.write(f"CosmosDenoisingStage: step {i}, uncond_velocity sum = {sum_value:.6f}\n")
+                            f.write(f"CosmosDenoisingStage: step {i}, noise_pred_uncond sum = {sum_value:.6f}\n")

-                        uncond_pred = (c_skip * latents + c_out * uncond_velocity.float()).to(target_dtype)
+                        uncond_pred = (c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype)

-                        # Apply conditional indicator masking for unconditional prediction
-                        if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
-                            unconditioning_latents = conditioning_latents  # Same as conditioning for Cosmos
+                        # Apply conditional indicator masking for unconditional prediction like diffusers
+                        if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None and unconditioning_latents is not None:
                             uncond_pred = batch.uncond_indicator * unconditioning_latents + (1 - batch.uncond_indicator) * uncond_pred

-                        # Apply guidance
-                        noise_pred = cond_pred + guidance_scale * (cond_pred - uncond_pred)
-                        sum_value = noise_pred.float().sum().item()
+                        # Apply guidance exactly like diffusers
+                        guidance_diff = cond_pred - uncond_pred
+                        final_pred = cond_pred + guidance_scale * guidance_diff
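+                        # Note: both the old and new code anchor CFG at the conditional branch,
+                        # cond + s * (cond - uncond), rather than the more common
+                        # uncond + s * (cond - uncond).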
+
+                        # Debug guidance computation
+                        if i < 3:  # Log first few steps
+                            logger.info(f"Step {i}: Guidance debug:")
+                            logger.info(f"  cond_pred sum = {cond_pred.float().sum().item():.6f}")
+                            logger.info(f"  uncond_pred sum = {uncond_pred.float().sum().item():.6f}")
+                            logger.info(f"  guidance_diff sum = {guidance_diff.float().sum().item():.6f}")
+                            logger.info(f"  guidance_scale = {guidance_scale}")
+                            logger.info(f"  final_pred sum = {final_pred.float().sum().item():.6f}")
+
+                        sum_value = final_pred.float().sum().item()
                         logger.info(f"CosmosDenoisingStage: step {i}, final noise_pred sum = {sum_value:.6f}")
                         # Write to output file
                         with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
                             f.write(f"CosmosDenoisingStage: step {i}, final noise_pred sum = {sum_value:.6f}\n")
                     else:
-                        noise_pred = cond_pred
-
-                    # Debug: Check for NaN values before conversion
-                    logger.info(f"Step {i}: Before conversion - latents NaN count: {torch.isnan(latents).sum()}")
-                    logger.info(f"Step {i}: Before conversion - noise_pred NaN count: {torch.isnan(noise_pred).sum()}")
-                    logger.info(f"Step {i}: current_sigma: {current_sigma}")
+                        final_pred = cond_pred
+                        if i < 3:
+                            logger.info(f"Step {i}: No CFG, using cond_pred directly: {final_pred.float().sum().item():.6f}")

-                    # Convert from velocity prediction to noise (Cosmos-specific conversion)
-                    # Add small epsilon to prevent division by zero
-                    current_sigma_safe = torch.clamp(current_sigma, min=1e-8)
-                    noise_pred = (latents - noise_pred) / current_sigma_safe
+                    # Convert to noise for scheduler step exactly like diffusers
+                    # Add safety check to prevent division by zero
+                    if current_sigma > 1e-8:
+                        noise_for_scheduler = (latents - final_pred) / current_sigma
+                    else:
+                        logger.warning(f"Step {i}: current_sigma too small ({current_sigma}), using final_pred directly")
+                        noise_for_scheduler = final_pred
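+                    # Derivation (assumed flow-matching parameterization): x_t = x0 + sigma * noise,
+                    # so noise = (x_t - x0) / sigma is the model output the scheduler step expects.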

-                    # Debug: Check for NaN values after conversion
-                    logger.info(f"Step {i}: After conversion - noise_pred NaN count: {torch.isnan(noise_pred).sum()}")
+                    # Debug: Check for NaN values before scheduler step
+                    if torch.isnan(noise_for_scheduler).sum() > 0:
+                        logger.error(f"Step {i}: NaN detected in noise_for_scheduler, sum: {noise_for_scheduler.float().sum().item()}")
+                        logger.error(f"Step {i}: latents sum: {latents.float().sum().item()}, final_pred sum: {final_pred.float().sum().item()}, current_sigma: {current_sigma}")

-                    # Standard scheduler step
-                    latents_before = latents.clone()
-                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                    # Standard scheduler step like diffusers
+                    latents = self.scheduler.step(noise_for_scheduler, t, latents, return_dict=False)[0]
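+                    # For FlowMatchEulerDiscreteScheduler this is effectively the Euler update
+                    # latents = latents + (sigma_next - sigma) * model_output (per the diffusers implementation).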
                     sum_value = latents.float().sum().item()
                     logger.info(f"CosmosDenoisingStage: step {i}, updated latents sum = {sum_value:.6f}")
                     # Write to output file
                     with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
                         f.write(f"CosmosDenoisingStage: step {i}, updated latents sum = {sum_value:.6f}\n")

-                    # Debug: Check for NaN values after scheduler step
-                    logger.info(f"Step {i}: After scheduler - latents NaN count: {torch.isnan(latents).sum()}")
-                    logger.info(f"Step {i}: latents shape change: {latents_before.shape} -> {latents.shape}")
-
-                    # Break if NaN detected
-                    if torch.isnan(latents).sum() > 0:
-                        logger.error(f"NaN detected at step {i}, breaking denoising loop")
-                        break
-
                     progress_bar.update()

         # Update batch with final latents