2222from  ...callbacks  import  MultiPipelineCallbacks , PipelineCallback 
2323from  ...image_processor  import  PipelineImageInput 
2424from  ...models  import  AutoencoderKLWan , CosmosTransformer3DModel 
25- from  ...schedulers  import  EDMEulerScheduler 
25+ from  ...schedulers  import  FlowMatchEulerDiscreteScheduler 
2626from  ...utils  import  is_cosmos_guardrail_available , is_torch_xla_available , logging , replace_example_docstring 
2727from  ...utils .torch_utils  import  randn_tensor 
2828from  ...video_processor  import  VideoProcessor 
@@ -153,7 +153,7 @@ def retrieve_latents(
153153
154154class  Cosmos2VideoToWorldPipeline (DiffusionPipeline ):
155155    r""" 
156-     Pipeline for text -to-image  generation using [Cosmos](https://github.com/NVIDIA/Cosmos ). 
156+     Pipeline for video -to-world  generation using [Cosmos Predict2 ](https://github.com/nvidia-cosmos/cosmos-predict2 ). 
157157
158158    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 
159159    implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
@@ -168,7 +168,7 @@ class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
168168            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
169169        transformer ([`CosmosTransformer3DModel`]): 
170170            Conditional Transformer to denoise the encoded image latents. 
171-         scheduler ([`EDMEulerScheduler `]): 
171+         scheduler ([`FlowMatchEulerDiscreteScheduler `]): 
172172            A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
173173        vae ([`AutoencoderKLWan`]): 
174174            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
@@ -185,7 +185,7 @@ def __init__(
185185        tokenizer : T5TokenizerFast ,
186186        transformer : CosmosTransformer3DModel ,
187187        vae : AutoencoderKLWan ,
188-         scheduler : EDMEulerScheduler ,
188+         scheduler : FlowMatchEulerDiscreteScheduler ,
189189        safety_checker : CosmosSafetyChecker  =  None ,
190190    ):
191191        super ().__init__ ()
@@ -206,6 +206,18 @@ def __init__(
206206        self .vae_scale_factor_spatial  =  2  **  len (self .vae .temperal_downsample ) if  getattr (self , "vae" , None ) else  8 
207207        self .video_processor  =  VideoProcessor (vae_scale_factor = self .vae_scale_factor_spatial )
208208
209+         self .sigma_max  =  80.0 
210+         self .sigma_min  =  0.002 
211+         self .sigma_data  =  1.0 
212+         self .final_sigmas_type  =  "sigma_min" 
213+         if  self .scheduler  is  not None :
214+             self .scheduler .register_to_config (
215+                 sigma_max = self .sigma_max ,
216+                 sigma_min = self .sigma_min ,
217+                 sigma_data = self .sigma_data ,
218+                 final_sigmas_type = self .final_sigmas_type ,
219+             )
220+ 
209221    # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds 
210222    def  _get_t5_prompt_embeds (
211223        self ,
@@ -340,7 +352,7 @@ def prepare_latents(
340352        num_channels_latents : 16 ,
341353        height : int  =  704 ,
342354        width : int  =  1280 ,
343-         num_frames : int  =  77 ,
355+         num_frames : int  =  93 ,
344356        do_classifier_free_guidance : bool  =  True ,
345357        dtype : Optional [torch .dtype ] =  None ,
346358        device : Optional [torch .device ] =  None ,
@@ -472,7 +484,7 @@ def __call__(
472484        negative_prompt : Optional [Union [str , List [str ]]] =  None ,
473485        height : int  =  704 ,
474486        width : int  =  1280 ,
475-         num_frames : int  =  77 ,
487+         num_frames : int  =  93 ,
476488        num_inference_steps : int  =  35 ,
477489        guidance_scale : float  =  7.0 ,
478490        fps : int  =  16 ,
@@ -505,7 +517,7 @@ def __call__(
505517                The height in pixels of the generated video. 
506518            width (`int`, defaults to `1280`): 
507519                The width in pixels of the generated video. 
508-             num_frames (`int`, defaults to `77 `): 
520+             num_frames (`int`, defaults to `93 `): 
509521                The number of frames in the generated video. 
510522            num_inference_steps (`int`, defaults to `35`): 
511523                The number of denoising steps. More denoising steps usually lead to a higher quality image at the 
@@ -616,7 +628,13 @@ def __call__(
616628        )
617629
618630        # 4. Prepare timesteps 
619-         timesteps , num_inference_steps  =  retrieve_timesteps (self .scheduler , num_inference_steps , device )
631+         sigmas_dtype  =  torch .float32  if  torch .backends .mps .is_available () else  torch .float64 
632+         sigmas  =  torch .linspace (0 , 1 , num_inference_steps , dtype = sigmas_dtype )
633+         timesteps , num_inference_steps  =  retrieve_timesteps (self .scheduler , device = device , sigmas = sigmas )
634+         if  self .scheduler .config .final_sigmas_type  ==  "sigma_min" :
635+             # Replace the last sigma (which is zero) with the minimum sigma value 
636+             timesteps [- 1 ] =  timesteps [- 2 ]
637+             self .scheduler .sigmas [- 1 ] =  self .scheduler .sigmas [- 2 ]
620638
621639        # 5. Prepare latent variables 
622640        vae_dtype  =  self .vae .dtype 
@@ -651,7 +669,7 @@ def __call__(
651669
652670        padding_mask  =  latents .new_zeros (1 , 1 , height , width , dtype = transformer_dtype )
653671        sigma_conditioning  =  torch .tensor (sigma_conditioning , dtype = torch .float32 , device = device )
654-         t_conditioning  =  self . scheduler . precondition_noise (sigma_conditioning )
672+         t_conditioning  =  sigma_conditioning   /   (sigma_conditioning   +   1 )
655673
656674        # 6. Denoising loop 
657675        num_warmup_steps  =  len (timesteps ) -  num_inference_steps  *  self .scheduler .order 
@@ -663,12 +681,15 @@ def __call__(
663681                    continue 
664682
665683                self ._current_timestep  =  t 
666-                 timestep  =  t .view (1 , 1 , 1 , 1 , 1 ).expand (
667-                     latents .size (0 ), - 1 , latents .size (2 ), - 1 , - 1 
668-                 )  # [B, 1, T, 1, 1] 
669684                current_sigma  =  self .scheduler .sigmas [i ]
670685
671-                 cond_latent  =  self .scheduler .scale_model_input (latents , t )
686+                 current_t  =  current_sigma  /  (current_sigma  +  1 )
687+                 c_in  =  1  -  current_t 
688+                 c_skip  =  1  -  current_t 
689+                 c_out  =  - current_t 
690+                 timestep  =  current_t .expand (latents .shape [0 ]).to (transformer_dtype )  # [B, 1, T, 1, 1] 
691+ 
692+                 cond_latent  =  latents  *  c_in 
672693                cond_latent  =  cond_indicator  *  conditioning_latents  +  (1  -  cond_indicator ) *  cond_latent 
673694                cond_latent  =  cond_latent .to (transformer_dtype )
674695                cond_timestep  =  cond_indicator  *  t_conditioning  +  (1  -  cond_indicator ) *  timestep 
@@ -683,11 +704,11 @@ def __call__(
683704                    padding_mask = padding_mask ,
684705                    return_dict = False ,
685706                )[0 ]
686-                 noise_pred  =  self . scheduler . precondition_outputs ( latents ,  noise_pred ,  current_sigma )
707+                 noise_pred  =  ( c_skip   *   latents   +   c_out   *   noise_pred . float ()). to ( transformer_dtype )
687708                noise_pred  =  cond_indicator  *  conditioning_latents  +  (1  -  cond_indicator ) *  noise_pred 
688709
689710                if  self .do_classifier_free_guidance :
690-                     uncond_latent  =  self . scheduler . scale_model_input ( latents ,  t ) 
711+                     uncond_latent  =  latents   *   c_in 
691712                    uncond_latent  =  uncond_indicator  *  unconditioning_latents  +  (1  -  uncond_indicator ) *  uncond_latent 
692713                    uncond_latent  =  uncond_latent .to (transformer_dtype )
693714                    uncond_timestep  =  uncond_indicator  *  t_conditioning  +  (1  -  uncond_indicator ) *  timestep 
@@ -702,15 +723,14 @@ def __call__(
702723                        padding_mask = padding_mask ,
703724                        return_dict = False ,
704725                    )[0 ]
705-                     noise_pred_uncond  =  self . scheduler . precondition_outputs ( latents ,  noise_pred_uncond ,  current_sigma )
726+                     noise_pred_uncond  =  ( c_skip   *   latents   +   c_out   *   noise_pred_uncond . float ()). to ( transformer_dtype )
706727                    noise_pred_uncond  =  (
707728                        uncond_indicator  *  unconditioning_latents  +  (1  -  uncond_indicator ) *  noise_pred_uncond 
708729                    )
709730                    noise_pred  =  noise_pred  +  self .guidance_scale  *  (noise_pred  -  noise_pred_uncond )
710731
711-                 latents  =  self .scheduler .step (
712-                     noise_pred , t , latents , pred_original_sample = noise_pred , return_dict = False 
713-                 )[0 ]
732+                 noise_pred  =  (latents  -  noise_pred ) /  current_sigma 
733+                 latents  =  self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
714734
715735                if  callback_on_step_end  is  not None :
716736                    callback_kwargs  =  {}
0 commit comments