@@ -149,20 +149,32 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided,
+            only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage
+            denoising. The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When
+            provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2", "image_encoder", "image_processor"]
 
     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
-        image_processor: CLIPImageProcessor,
         transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        image_processor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModel = None,
+        transformer_2: WanTransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()
 
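For orientation, a minimal usage sketch of the two-stage setup this constructor enables. It assumes a Wan 2.2-style checkpoint that ships both transformers and a `boundary_ratio` in its pipeline config; the repo id and dtype below are illustrative, not taken from this diff.

```python
# Minimal sketch, not from this diff: load a checkpoint that provides both
# `transformer` (high-noise) and `transformer_2` (low-noise) plus a boundary_ratio.
import torch
from diffusers import WanImageToVideoPipeline

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",  # assumed repo id for a two-transformer checkpoint
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# With e.g. boundary_ratio=0.9 and num_train_timesteps=1000, the switch point is
# timestep 900: t >= 900 runs `transformer`, t < 900 runs `transformer_2`.
print(pipe.config.boundary_ratio)
```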
@@ -174,7 +186,9 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             image_processor=image_processor,
+            transformer_2=transformer_2,
         )
+        self.register_to_config(boundary_ratio=boundary_ratio)
 
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
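The scale factors are derived from the VAE's `temperal_downsample` flags (the attribute really is spelled that way in `AutoencoderKLWan`). A quick worked example of the arithmetic, assuming a typical three-level config:

```python
# Assumed VAE config: three temporal levels, of which two downsample.
temperal_downsample = [False, True, True]

vae_scale_factor_temporal = 2 ** sum(temperal_downsample)  # 2**2 == 4 frames per latent frame
vae_scale_factor_spatial = 2 ** len(temperal_downsample)   # 2**3 == 8 pixels per latent pixel
assert (vae_scale_factor_temporal, vae_scale_factor_spatial) == (4, 8)
```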
@@ -325,6 +339,7 @@ def check_inputs(
         negative_prompt_embeds=None,
         image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if image is not None and image_embeds is not None:
             raise ValueError(
@@ -368,6 +383,12 @@ def check_inputs(
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
 
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+        if self.config.boundary_ratio is not None and image_embeds is not None:
+            raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is configured.")
+
     def prepare_latents(
         self,
         image: PipelineImageInput,
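The two new checks are mutually exclusive by design: `guidance_scale_2` only makes sense on the two-stage path, while precomputed `image_embeds` only make sense on the single-transformer path. A standalone restatement of the rule, using a hypothetical `validate` helper rather than the pipeline itself:

```python
# Sketch: the two argument combinations check_inputs now rejects.
def validate(boundary_ratio, guidance_scale_2, image_embeds):
    if boundary_ratio is None and guidance_scale_2 is not None:
        raise ValueError("`guidance_scale_2` requires a configured `boundary_ratio`.")
    if boundary_ratio is not None and image_embeds is not None:
        raise ValueError("`image_embeds` requires `boundary_ratio` to be unset.")

validate(boundary_ratio=None, guidance_scale_2=None, image_embeds=object())  # ok: Wan 2.1-style path
validate(boundary_ratio=0.9, guidance_scale_2=3.0, image_embeds=None)        # ok: two-stage path
```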
@@ -483,6 +504,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -527,6 +549,9 @@ def __call__(
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are
                 closely linked to the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage (`transformer_2`). If `None` and `boundary_ratio` is set, it
+                defaults to `guidance_scale`. Only used when both `transformer_2` and `boundary_ratio` are set.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
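Continuing the loading sketch above, a call on the two-stage path might look like this; the prompt, URL, and guidance values are illustrative:

```python
# Illustrative call; `pipe` as in the loading sketch above.
from diffusers.utils import export_to_video, load_image

image = load_image("https://example.com/first_frame.png")  # placeholder URL

output = pipe(
    image=image,
    prompt="a red panda climbing a snowy branch",
    num_frames=81,                 # satisfies num_frames % 4 == 1
    num_inference_steps=50,
    guidance_scale=3.5,            # high-noise stage (`transformer`)
    guidance_scale_2=4.0,          # low-noise stage (`transformer_2`); falls back to guidance_scale if None
).frames[0]
export_to_video(output, "output.mp4", fps=16)
```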
@@ -589,6 +614,7 @@ def __call__(
             negative_prompt_embeds,
             image_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )
 
         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +624,12 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)
 
+
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
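The frame-count adjustment above exists because the causal video VAE maps 4k+1 pixel frames to k+1 latent frames, so `num_frames % vae_scale_factor_temporal` must equal 1. A worked example with the default factor of 4:

```python
# Worked example of the rounding, assuming vae_scale_factor_temporal == 4.
vae_scale_factor_temporal = 4

for requested in (81, 80, 83, 1):
    if requested % vae_scale_factor_temporal != 1:
        adjusted = requested // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
    else:
        adjusted = requested
    adjusted = max(adjusted, 1)
    print(requested, "->", adjusted)
# 81 -> 81, 80 -> 81, 83 -> 81, 1 -> 1
```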
@@ -631,13 +662,15 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        if image_embeds is None:
-            if last_image is None:
-                image_embeds = self.encode_image(image, device)
-            else:
-                image_embeds = self.encode_image([image, last_image], device)
-            image_embeds = image_embeds.repeat(batch_size, 1, 1)
-            image_embeds = image_embeds.to(transformer_dtype)
+
+        if self.config.boundary_ratio is None:
+            if image_embeds is None:
+                if last_image is None:
+                    image_embeds = self.encode_image(image, device)
+                else:
+                    image_embeds = self.encode_image([image, last_image], device)
+                image_embeds = image_embeds.repeat(batch_size, 1, 1)
+                image_embeds = image_embeds.to(transformer_dtype)
 
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
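One consequence of this gating worth calling out: on the two-stage path (`boundary_ratio` set), the CLIP image-encoder branch is skipped entirely, so `image_embeds` stays `None` and `encoder_hidden_states_image` is simply `None` in the transformer calls below; image conditioning then reaches the model only through the latent `condition` tensor, and `check_inputs` enforces the same invariant by rejecting user-supplied `image_embeds`. A compact restatement with a hypothetical helper:

```python
# Sketch of the invariant established by this block (hypothetical helper).
def resolve_image_embeds(boundary_ratio, image_embeds, encode):
    if boundary_ratio is None and image_embeds is None:
        image_embeds = encode()  # CLIP path, Wan 2.1-style
    return image_embeds  # stays None whenever boundary_ratio is set

assert resolve_image_embeds(0.9, None, encode=lambda: "clip") is None
assert resolve_image_embeds(None, None, encode=lambda: "clip") == "clip"
```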
@@ -668,16 +701,33 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
+                noise_pred = current_model(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
@@ -687,15 +737,15 @@ def __call__(
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                    noise_uncond = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=negative_prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
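To make the routing concrete, here is a standalone sketch of the per-step selection with illustrative numbers (`num_train_timesteps=1000`, `boundary_ratio=0.875`). Note that only the scale changes in the classifier-free guidance update; the formula itself stays the standard `uncond + g * (cond - uncond)`:

```python
# Standalone sketch of per-step model routing; values are illustrative.
num_train_timesteps = 1000
boundary_ratio = 0.875
boundary_timestep = boundary_ratio * num_train_timesteps  # 875.0

for t in (999, 900, 875, 874, 500, 10):
    if t >= boundary_timestep:
        stage = "high-noise -> transformer, guidance_scale"
    else:
        stage = "low-noise  -> transformer_2, guidance_scale_2"
    print(f"t={t}: {stage}")
# 999, 900, 875 use `transformer`; 874, 500, 10 use `transformer_2`
```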