35 changes: 34 additions & 1 deletion scripts/convert_wan_to_diffusers.py
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-VACE-Fun-14B":
config = {
"model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-I2V-14B-720p":
config = {
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
@@ -975,7 +998,17 @@ def get_args():
image_encoder=image_encoder,
image_processor=image_processor,
)
elif "VACE" in args.model_type:
elif "Wan2.2-VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
transformer_2=transformer_2,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
boundary_ratio=0.875,
)
elif "Wan-VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
text_encoder=text_encoder,
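For reference, the keys in the new "Wan2.2-VACE-Fun-14B" entry's `diffusers_config` correspond to `WanVACETransformer3DModel` constructor arguments. A minimal sketch of the transformer that entry describes (the keyword names are an assumption about the model's `__init__` signature, with values copied from the entry above; not part of the diff):

from diffusers import WanVACETransformer3DModel

# Hyperparameters copied from the "Wan2.2-VACE-Fun-14B" config above.
transformer = WanVACETransformer3DModel(
    added_kv_proj_dim=None,
    attention_head_dim=128,      # 40 heads x 128 dims per head
    cross_attn_norm=True,
    eps=1e-06,
    ffn_dim=13824,
    freq_dim=256,
    in_channels=16,
    num_attention_heads=40,
    num_layers=40,
    out_channels=16,
    patch_size=[1, 2, 2],        # temporal, height, width patching
    qk_norm="rms_norm_across_heads",
    text_dim=4096,
    vace_layers=[0, 5, 10, 15, 20, 25, 30, 35],
    vace_in_channels=96,
)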
75 changes: 59 additions & 16 deletions src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""

model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]

def __init__(
self,
@@ -170,6 +180,8 @@ def __init__(
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
super().__init__()

@@ -178,9 +190,10 @@ def __init__(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
transformer_2=transformer_2,
scheduler=scheduler,
)

self.register_to_config(boundary_ratio=boundary_ratio)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -321,6 +334,7 @@ def check_inputs(
video=None,
mask=None,
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ def check_inputs(
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")

if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -667,6 +683,7 @@ def __call__(
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -728,6 +745,10 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
guidance_scale_2 (`float`, *optional*, defaults to `None`):
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -793,6 +814,7 @@ def __call__(
video,
mask,
reference_images,
guidance_scale_2,
)

if num_frames % self.vae_scale_factor_temporal != 1:
@@ -802,7 +824,11 @@ def __call__(
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)

if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale

self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -896,36 +922,53 @@ def __call__(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)

if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
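# Illustrative arithmetic (values assumed, not part of this diff): with the converter's
# boundary_ratio=0.875 and a scheduler configured with num_train_timesteps=1000,
# boundary_timestep = 0.875 * 1000 = 875.0, so `transformer` handles timesteps
# t >= 875 (high-noise stage) and `transformer_2` handles t < 875 (low-noise stage).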

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

self._current_timestep = t

if boundary_timestep is None or t >= boundary_timestep:
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2

latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])

- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- control_hidden_states=conditioning_latents,
- control_hidden_states_scale=conditioning_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ control_hidden_states=conditioning_latents,
+ control_hidden_states_scale=conditioning_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
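For context, a minimal usage sketch of the two-stage pipeline after conversion. The checkpoint path, resolution, guidance values, and the blank-video setup for plain text-to-video are illustrative assumptions, not part of the diff:

import torch
from PIL import Image
from diffusers import WanVACEPipeline
from diffusers.utils import export_to_video

# Placeholder path: wherever convert_wan_to_diffusers.py saved the converted pipeline.
pipe = WanVACEPipeline.from_pretrained(
    "./Wan2.2-VACE-Fun-A14B-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# A fully masked gray control video makes VACE behave like plain text-to-video.
height, width, num_frames = 480, 832, 81
video = [Image.new("RGB", (width, height), (128, 128, 128))] * num_frames
mask = [Image.new("L", (width, height), 255)] * num_frames

# boundary_ratio=0.875 is registered at conversion time, so `transformer` runs the
# high-noise steps and `transformer_2` the low-noise steps; `guidance_scale_2`
# applies to the low-noise stage and falls back to `guidance_scale` when unset.
frames = pipe(
    prompt="A cat dancing in the rain, cinematic lighting",
    negative_prompt="blurry, low quality, distorted",
    video=video,
    mask=mask,
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=50,
    guidance_scale=5.0,
    guidance_scale_2=4.0,
).frames[0]
export_to_video(frames, "wan22_vace_fun.mp4", fps=16)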
1 change: 1 addition & 0 deletions tests/pipelines/wan/test_wan_vace.py
@@ -87,6 +87,7 @@ def get_dummy_components(self):
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer_2": None,
}
return components
