@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanTransformer3DModel`]):
+        transformer ([`WanVACETransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
         scheduler ([`UniPCMultistepScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]

     def __init__(
         self,
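
The `boundary_ratio` docstring above describes the expert-switching rule; for concreteness, here is a minimal sketch of the arithmetic (illustrative values only; `num_train_timesteps=1000` is the `FlowMatchEulerDiscreteScheduler` default, not something this PR sets):

```python
# boundary_timestep = boundary_ratio * num_train_timesteps
boundary_ratio = 0.875        # illustrative value, not from this PR
num_train_timesteps = 1000    # FlowMatchEulerDiscreteScheduler default
boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

# Timesteps >= 875 (early, high-noise steps) are routed to `transformer`;
# timesteps < 875 (late, low-noise steps) go to `transformer_2`.
assert 900 >= boundary_timestep  # high-noise stage
assert 400 < boundary_timestep   # low-noise stage
```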
@@ -170,6 +180,8 @@ def __init__(
         transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: WanVACETransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()

@@ -178,9 +190,10 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             transformer=transformer,
+            transformer_2=transformer_2,
             scheduler=scheduler,
         )
-
+        self.register_to_config(boundary_ratio=boundary_ratio)
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -321,6 +334,7 @@ def check_inputs(
         video=None,
         mask=None,
         reference_images=None,
+        guidance_scale_2=None,
     ):
         base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ def check_inputs(
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -667,6 +683,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -793,6 +814,7 @@ def __call__(
             video,
             mask,
             reference_images,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -896,36 +922,53 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    control_hidden_states=conditioning_latents,
-                    control_hidden_states_scale=conditioning_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         control_hidden_states=conditioning_latents,
                         control_hidden_states_scale=conditioning_scale,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            control_hidden_states=conditioning_latents,
+                            control_hidden_states_scale=conditioning_scale,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
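
Taken together, these changes enable Wan 2.2-style two-expert denoising in the VACE pipeline. A minimal usage sketch, assuming a hypothetical checkpoint that ships a low-noise expert in a `transformer_2` subfolder (the repo ID and all parameter values below are placeholders, not from this PR):

```python
import torch
from diffusers import WanVACEPipeline, WanVACETransformer3DModel

repo_id = "your-org/wan2.2-vace-diffusers"  # hypothetical repo ID

# Load the low-noise expert and hand it to the pipeline as `transformer_2`.
transformer_2 = WanVACETransformer3DModel.from_pretrained(
    repo_id, subfolder="transformer_2", torch_dtype=torch.bfloat16
)
pipe = WanVACEPipeline.from_pretrained(
    repo_id,
    transformer_2=transformer_2,
    boundary_ratio=0.875,  # switch experts at t = 0.875 * num_train_timesteps
    torch_dtype=torch.bfloat16,
).to("cuda")

# Per-stage guidance scales; omitting guidance_scale_2 falls back to
# guidance_scale when boundary_ratio is set.
video = pipe(
    prompt="a cat surfing a wave at sunset",
    num_frames=81,
    guidance_scale=4.0,    # high-noise stage (transformer)
    guidance_scale_2=3.0,  # low-noise stage (transformer_2)
).frames[0]
```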