@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanTransformer3DModel`]):
+        transformer ([`WanVACETransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
         scheduler ([`UniPCMultistepScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """
 
     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]
 
     def __init__(
         self,
@@ -170,6 +180,8 @@ def __init__(
         transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: WanVACETransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()
 
@@ -178,9 +190,10 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             transformer=transformer,
+            transformer_2=transformer_2,
             scheduler=scheduler,
         )
-
+        self.register_to_config(boundary_ratio=boundary_ratio)
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
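
For orientation (not part of the diff): the arithmetic described in the new `boundary_ratio` docstring reduces to the sketch below. The `boundary_ratio` value is a placeholder, and `num_train_timesteps` is assumed to take the `FlowMatchEulerDiscreteScheduler` default of 1000.

```python
# Minimal sketch of the two-stage split, assuming a hypothetical
# boundary_ratio of 0.875 and the scheduler default of 1000 train timesteps.
boundary_ratio = 0.875
num_train_timesteps = 1000
boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

for t in (999, 875, 874, 100):
    stage = "transformer (high-noise)" if t >= boundary_timestep else "transformer_2 (low-noise)"
    print(f"t={t}: {stage}")
# t=999 and t=875 run `transformer`; t=874 and t=100 run `transformer_2`.
```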
@@ -321,6 +334,7 @@ def check_inputs(
         video=None,
         mask=None,
         reference_images=None,
+        guidance_scale_2=None,
     ):
         base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ def check_inputs(
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
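
To illustrate the new guard (hypothetical call, not from this diff): a pipeline whose `boundary_ratio` config is `None`, e.g. one loaded from a single-transformer Wan 2.1 style checkpoint, now rejects `guidance_scale_2` up front instead of silently ignoring it.

```python
# Assumes `pipe` is a WanVACEPipeline with pipe.config.boundary_ratio == None.
try:
    pipe(prompt="a red panda playing a ukulele", guidance_scale_2=3.0)
except ValueError as err:
    print(err)
# `guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.
```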
@@ -667,6 +683,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
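
A hedged usage sketch of the new argument; the checkpoint id is a placeholder and the guidance values are illustrative, not recommendations from this PR.

```python
import torch
from diffusers import WanVACEPipeline
from diffusers.utils import export_to_video

# Placeholder checkpoint id; assumes a two-transformer (Wan 2.2 style) repo
# whose config carries a non-None boundary_ratio.
pipe = WanVACEPipeline.from_pretrained("<wan2.2-vace-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

output = pipe(
    prompt="a cat surfing a wave at sunset",
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=4.0,    # high-noise steps on `transformer`
    guidance_scale_2=3.0,  # low-noise steps on `transformer_2`; falls back to guidance_scale if omitted
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```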
@@ -793,6 +814,7 @@ def __call__(
             video,
             mask,
             reference_images,
+            guidance_scale_2,
         )
 
         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)
 
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
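
A quick worked example of the frame-count snapping visible in this hunk's context lines, assuming the default `vae_scale_factor_temporal` of 4:

```python
# num_frames must satisfy num_frames % vae_scale_factor_temporal == 1,
# so other requests snap down to the nearest 4k + 1 (and never below 1).
vae_scale_factor_temporal = 4
for requested in (80, 81, 83):
    if requested % vae_scale_factor_temporal != 1:
        adjusted = requested // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
    else:
        adjusted = requested
    print(requested, "->", max(adjusted, 1))  # 80 -> 81, 81 -> 81, 83 -> 81
```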
@@ -896,36 +922,53 @@ def __call__(
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
             self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    control_hidden_states=conditioning_latents,
-                    control_hidden_states_scale=conditioning_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         control_hidden_states=conditioning_latents,
                         control_hidden_states_scale=conditioning_scale,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            control_hidden_states=conditioning_latents,
+                            control_hidden_states_scale=conditioning_scale,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
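
Taken together, the denoising-loop changes reduce to the dispatch below. This is a condensed illustration, not the diff itself: `predict` is a hypothetical helper standing in for the full `current_model(...)` call, and `cache_context` scopes any enabled model cache so the cond and uncond passes do not overwrite each other's cached states.

```python
# Condensed view of the per-step logic added above (illustrative only).
for t in timesteps:
    if boundary_timestep is None or t >= boundary_timestep:
        current_model, current_guidance_scale = self.transformer, guidance_scale      # Wan 2.1 / high-noise
    else:
        current_model, current_guidance_scale = self.transformer_2, guidance_scale_2  # Wan 2.2 low-noise

    with current_model.cache_context("cond"):
        cond = predict(current_model, prompt_embeds)  # hypothetical helper
    if self.do_classifier_free_guidance:
        with current_model.cache_context("uncond"):
            uncond = predict(current_model, negative_prompt_embeds)
        noise_pred = uncond + current_guidance_scale * (cond - uncond)
    else:
        noise_pred = cond
```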