@@ -600,7 +600,7 @@ def verify_output(self, batch: ForwardBatch,
         return result
 
 
-class CosmosDenoisingStage(PipelineStage):
+class CosmosDenoisingStage(DenoisingStage):
     """
     Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
@@ -613,10 +613,8 @@ def __init__(self,
                  transformer,
                  scheduler,
                  pipeline=None) -> None:
-        super().__init__()
-        self.transformer = transformer
-        self.scheduler = scheduler  # FlowMatchEulerDiscreteScheduler
-        self.pipeline = weakref.ref(pipeline) if pipeline else None
+        super().__init__(transformer, scheduler, pipeline)
+        # FlowMatchEulerDiscreteScheduler is already set by parent
 
     def forward(
         self,
@@ -659,7 +657,7 @@ def forward(
         self.scheduler.set_timesteps(device=latents.device, sigmas=sigmas)
 
         # Initialize with maximum noise
-        latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.config.sigma_max
+        latents = torch.randn_like(latents, dtype=torch.float32) * self.scheduler.sigma_max
 
         # Prepare conditional frame handling (if needed)
         # This would be implemented based on batch.conditioning_latents or similar
@@ -692,12 +690,17 @@ def forward(
             # Add conditional frame handling here if needed:
             # cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
 
-            cond_velocity = self.transformer(
-                hidden_states=cond_latent.to(target_dtype),
-                timestep=timestep.to(target_dtype),
-                encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
-                return_dict=False,
-            )[0]
+            with set_forward_context(
+                current_timestep=i,
+                attn_metadata=None,  # TODO: implement attention metadata if needed
+                forward_batch=batch,
+            ):
+                cond_velocity = self.transformer(
+                    hidden_states=cond_latent.to(target_dtype),
+                    timestep=timestep.to(target_dtype),
+                    encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
+                    return_dict=False,
+                )[0]
 
             # Apply preconditioning
             cond_pred = (c_skip * latents + c_out * cond_velocity.float()).to(target_dtype)
@@ -706,12 +709,17 @@ def forward(
             if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
                 uncond_latent = latents * c_in
 
-                uncond_velocity = self.transformer(
-                    hidden_states=uncond_latent.to(target_dtype),
-                    timestep=timestep.to(target_dtype),
-                    encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
-                    return_dict=False,
-                )[0]
+                with set_forward_context(
+                    current_timestep=i,
+                    attn_metadata=None,  # TODO: implement attention metadata if needed
+                    forward_batch=batch,
+                ):
+                    uncond_velocity = self.transformer(
+                        hidden_states=uncond_latent.to(target_dtype),
+                        timestep=timestep.to(target_dtype),
+                        encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
+                        return_dict=False,
+                    )[0]
 
                 uncond_pred = (c_skip * latents + c_out * uncond_velocity.float()).to(target_dtype)
 
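The hunks above scale the latents with c_in and combine c_skip and c_out to form the predictions, but the coefficients themselves are computed elsewhere in the stage. For reference, here is a minimal sketch of the standard EDM-style preconditioning these names usually correspond to, assuming the Karras et al. formulation and an illustrative sigma_data value (both are assumptions; the actual stage derives sigma per step from its scheduler):

import torch

# Illustrative values only -- the real stage takes sigma from the scheduler
# for the current timestep and fixes sigma_data itself.
sigma_data = 0.5
sigma = torch.tensor(80.0)  # current noise level

# Standard EDM preconditioning coefficients (Karras et al., 2022)
c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)                  # input scaling
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)                # skip-connection weight
c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)  # output weight

# Mirrors the usage in the diff:
#   cond_latent = latents * c_in
#   cond_pred   = c_skip * latents + c_out * cond_velocity

After uncond_pred, the loop would typically form the guided prediction as uncond_pred + guidance_scale * (cond_pred - uncond_pred) before the scheduler step; that part of the method sits outside these hunks.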