@@ -600,6 +600,166 @@ def verify_output(self, batch: ForwardBatch,
600600 return result
601601
602602
class CosmosDenoisingStage(PipelineStage):
    """
    Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.

    Implements the diffusers-compatible Cosmos denoising loop: EDM-style
    preconditioning (c_in / c_skip / c_out derived from the current sigma),
    velocity prediction, classifier-free guidance, and lazy transformer
    loading. Compatible with Hugging Face Cosmos models.
    """

    def __init__(self,
                 transformer,
                 scheduler,
                 pipeline=None) -> None:
        super().__init__()
        self.transformer = transformer
        self.scheduler = scheduler  # FlowMatchEulerDiscreteScheduler
        # Weak reference so this stage does not keep the owning pipeline alive.
        self.pipeline = weakref.ref(pipeline) if pipeline else None

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the diffusers-style Cosmos denoising loop.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        pipeline = self.pipeline() if self.pipeline else None
        # Lazily load the transformer on first use and register it with the
        # owning pipeline so later stages can reuse the loaded module.
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(
                fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        # Setup precision and autocast settings.
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32
                            ) and not fastvideo_args.disable_autocast

        # Get latents and sampling parameters.
        latents = batch.latents
        num_inference_steps = batch.num_inference_steps
        guidance_scale = batch.guidance_scale

        # Build the sigma schedule. float64 matches the diffusers reference
        # implementation; MPS has no float64 support, so fall back to float32.
        sigmas_dtype = torch.float32 if torch.backends.mps.is_available(
        ) else torch.float64
        sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
        self.scheduler.set_timesteps(device=latents.device, sigmas=sigmas)
        # Iterate the scheduler's own timesteps. The previous code iterated a
        # bare arange, which handed step *indices* (0..N-1) to
        # scheduler.step(); FlowMatchEulerDiscreteScheduler expects the actual
        # timestep values it produced in set_timesteps().
        timesteps = self.scheduler.timesteps

        # Initialize with maximum noise.
        # NOTE(review): this discards any latents prepared by an earlier stage
        # (including their seeding) — confirm this is intended rather than
        # scaling the incoming latents by sigma_max.
        latents = torch.randn_like(
            latents, dtype=torch.float32) * self.scheduler.config.sigma_max

        # Conditional frame handling (if needed) would be prepared here based
        # on batch.conditioning_latents or similar.

        # Sampling loop.
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Allow external interruption (diffusers convention).
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue

                # EDM-style preconditioning coefficients derived from the
                # current sigma (mirrors the diffusers Cosmos pipelines).
                current_sigma = self.scheduler.sigmas[i]
                current_t = current_sigma / (current_sigma + 1)
                c_in = 1 - current_t
                c_skip = 1 - current_t
                c_out = -current_t

                # Per-frame timestep tensor of shape (B, 1, T, 1, 1).
                timestep = current_t.view(1, 1, 1, 1, 1).expand(
                    latents.size(0), -1, latents.size(2), -1, -1)

                # Use the latents' own device rather than hard-coding "cuda"
                # so the stage also runs on CPU/MPS.
                with torch.autocast(device_type=latents.device.type,
                                    dtype=target_dtype,
                                    enabled=autocast_enabled):

                    # Conditional forward pass.
                    cond_latent = latents * c_in
                    # Conditional-frame handling would go here if needed:
                    # cond_latent = cond_indicator * conditioning_latents \
                    #     + (1 - cond_indicator) * cond_latent

                    cond_velocity = self.transformer(
                        hidden_states=cond_latent.to(target_dtype),
                        timestep=timestep.to(target_dtype),
                        encoder_hidden_states=batch.prompt_embeds[0].to(
                            target_dtype),
                        return_dict=False,
                    )[0]

                    # Apply preconditioning to obtain the denoised prediction.
                    cond_pred = (c_skip * latents +
                                 c_out * cond_velocity.float()).to(target_dtype)

                    # Classifier-free guidance.
                    if (batch.do_classifier_free_guidance
                            and batch.negative_prompt_embeds is not None):
                        uncond_latent = latents * c_in

                        uncond_velocity = self.transformer(
                            hidden_states=uncond_latent.to(target_dtype),
                            timestep=timestep.to(target_dtype),
                            encoder_hidden_states=batch.
                            negative_prompt_embeds[0].to(target_dtype),
                            return_dict=False,
                        )[0]

                        uncond_pred = (c_skip * latents +
                                       c_out * uncond_velocity.float()).to(
                                           target_dtype)

                        # Cosmos guidance form: cond + s * (cond - uncond).
                        velocity_pred = cond_pred + guidance_scale * (
                            cond_pred - uncond_pred)
                    else:
                        velocity_pred = cond_pred

                # Convert the denoised prediction back to noise space for the
                # flow-match Euler update.
                noise_pred = (latents - velocity_pred) / current_sigma

                # Standard scheduler step.
                latents = self.scheduler.step(noise_pred, t, latents,
                                              return_dict=False)[0]

                progress_bar.update()

        # Hand the final latents back to the batch.
        batch.latents = latents

        return batch

    def verify_input(self, batch: ForwardBatch,
                     fastvideo_args: FastVideoArgs) -> VerificationResult:
        """Verify Cosmos denoising stage inputs."""
        result = VerificationResult()
        result.add_check("latents", batch.latents,
                         [V.is_tensor, V.with_dims(5)])
        result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
        result.add_check("num_inference_steps", batch.num_inference_steps,
                         V.positive_int)
        result.add_check("guidance_scale", batch.guidance_scale,
                         V.positive_float)
        result.add_check("do_classifier_free_guidance",
                         batch.do_classifier_free_guidance, V.bool_value)
        # Negative embeds are only required when CFG is enabled.
        result.add_check(
            "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x:
            not batch.do_classifier_free_guidance or V.list_not_empty(x))
        return result

    def verify_output(self, batch: ForwardBatch,
                      fastvideo_args: FastVideoArgs) -> VerificationResult:
        """Verify Cosmos denoising stage outputs."""
        result = VerificationResult()
        result.add_check("latents", batch.latents,
                         [V.is_tensor, V.with_dims(5)])
        return result
761+
762+
603763class DmdDenoisingStage (DenoisingStage ):
604764 """
605765 Denoising stage for DMD.