From d162ace8e8ffe6484eec38755ef86e47949d5f0d Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Fri, 12 Sep 2025 11:16:41 +0200
Subject: [PATCH 1/5] support Wan2.2-VACE-Fun-A14B

---
 scripts/convert_wan_to_diffusers.py      | 33 ++++++++
 .../pipelines/wan/pipeline_wan_vace.py   | 75 +++++++++++++++----
 2 files changed, 92 insertions(+), 16 deletions(-)

diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 599c90be57ce..dadaab4670de 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
         }
         RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
         SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-VACE-Fun-14B":
+        config = {
+            "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan2.2-I2V-14B-720p":
         config = {
             "model_id": "Wan-AI/Wan2.2-I2V-A14B",
@@ -983,6 +1006,16 @@ def get_args():
             vae=vae,
             scheduler=scheduler,
         )
+    elif "Wan2.2-VACE" in args.model_type:
+        pipe = WanVACEPipeline(
+            transformer=transformer,
+            transformer_2=transformer_2,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            boundary_ratio=0.875,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 99e1f5116b85..9b3daa469631 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanTransformer3DModel`]):
+        transformer ([`WanVACETransformer3DModel`]):
             Conditional Transformer to denoise the input latents.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
         scheduler ([`UniPCMultistepScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer_2"]

     def __init__(
         self,
@@ -170,6 +180,8 @@ def __init__(
         transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer_2: WanVACETransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
     ):
         super().__init__()

@@ -178,9 +190,10 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             transformer=transformer,
+            transformer_2=transformer_2,
             scheduler=scheduler,
         )
-
+        self.register_to_config(boundary_ratio=boundary_ratio)
         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -321,6 +334,7 @@ def check_inputs(
         video=None,
         mask=None,
         reference_images=None,
+        guidance_scale_2=None,
     ):
         base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
         if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ def check_inputs(
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
@@ -667,6 +683,7 @@ def __call__(
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@ def __call__(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -793,6 +814,7 @@ def __call__(
             video,
             mask,
             reference_images,
+            guidance_scale_2
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@ def __call__(
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
         num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -896,36 +922,53 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    control_hidden_states=conditioning_latents,
-                    control_hidden_states_scale=conditioning_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         control_hidden_states=conditioning_latents,
                         control_hidden_states_scale=conditioning_scale,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            control_hidden_states=conditioning_latents,
+                            control_hidden_states_scale=conditioning_scale,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

From 3c0c521a5d45450ae72e70fe77c935b423315141 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Fri, 12 Sep 2025 12:03:07 +0200
Subject: [PATCH 2/5] support Wan2.2-VACE-Fun-A14B

---
 scripts/convert_wan_to_diffusers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index dadaab4670de..554f7870b7fa 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -998,23 +998,23 @@ def get_args():
             image_encoder=image_encoder,
             image_processor=image_processor,
         )
-    elif "VACE" in args.model_type:
"Wan2.2-VACE" in args.model_type: pipe = WanVACEPipeline( transformer=transformer, + transformer_2=transformer_2, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, scheduler=scheduler, + boundary_ratio=0.875, ) - elif "Wan2.2-VACE" in args.model_type: + elif "VACE" in args.model_type: pipe = WanVACEPipeline( transformer=transformer, - transformer_2=transformer_2, text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, scheduler=scheduler, - boundary_ratio=0.875, ) else: pipe = WanPipeline( From bfb87252e2c37830c7deac7af2672a31e690bc72 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 12 Sep 2025 17:30:38 +0200 Subject: [PATCH 3/5] support Wan2.2-VACE-Fun-A14B --- scripts/convert_wan_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 554f7870b7fa..39a364b07d78 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -1008,7 +1008,7 @@ def get_args(): scheduler=scheduler, boundary_ratio=0.875, ) - elif "VACE" in args.model_type: + elif "Wan-VACE" in args.model_type: pipe = WanVACEPipeline( transformer=transformer, text_encoder=text_encoder, From 417692e94d683945be8b81ce514c9ffc9dd46a07 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 13 Sep 2025 08:16:35 +0000 Subject: [PATCH 4/5] Apply style fixes --- src/diffusers/pipelines/wan/pipeline_wan_vace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index 9b3daa469631..eab1aacfc58e 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -814,7 +814,7 @@ def __call__( video, mask, reference_images, - guidance_scale_2 + guidance_scale_2, ) if num_frames % self.vae_scale_factor_temporal != 1: From deae16a7d8f64ea2a2216b2c314674fd85f5386a Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 15 Sep 2025 16:51:03 +0200 Subject: [PATCH 5/5] test --- tests/pipelines/wan/test_wan_vace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py index ed13d5649dc3..f99863c88092 100644 --- a/tests/pipelines/wan/test_wan_vace.py +++ b/tests/pipelines/wan/test_wan_vace.py @@ -87,6 +87,7 @@ def get_dummy_components(self): "scheduler": scheduler, "text_encoder": text_encoder, "tokenizer": tokenizer, + "transformer_2": None, } return components