71 changes: 49 additions & 22 deletions src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
transformer ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
"""

model_cpu_offload_seq = "text_encoder->transformer->vae"

Review comment: Does this not need the following?

    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"

When you are using two transformers, I'm guessing you are using the other WAN2.2 ones; are those compatible with VACE?

_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
_optional_components = ["transformer", "transformer_2"]

def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer: WanVACETransformer3DModel = None,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
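For context, here is a minimal sketch (not part of this diff) of the boundary_ratio convention described in the updated docstring. The helper name, the config lookup, and the num_train_timesteps default are assumptions for illustration only:

def select_transformer(pipe, timestep, num_train_timesteps=1000):
    # No boundary: use whichever transformer was provided.
    if pipe.config.boundary_ratio is None:
        return pipe.transformer if pipe.transformer is not None else pipe.transformer_2
    boundary_timestep = pipe.config.boundary_ratio * num_train_timesteps
    # High-noise (early) timesteps go to `transformer`, low-noise timesteps to `transformer_2`.
    if timestep >= boundary_timestep and pipe.transformer is not None:
        return pipe.transformer
    return pipe.transformer_2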
@@ -336,7 +338,15 @@ def check_inputs(
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if self.transformer is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
elif self.transformer_2 is not None:
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
else:
raise ValueError(
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
)

if height % base != 0 or width % base != 0:
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")

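To make the divisibility requirement concrete, a small worked example assuming the typical Wan configuration of an 8x spatial VAE downscale and a (1, 2, 2) patch size (both values are assumptions, not read from this diff):

vae_scale_factor_spatial = 8                       # assumed spatial VAE downscale factor
patch_size = (1, 2, 2)                             # assumed transformer patch size (t, h, w)
base = vae_scale_factor_spatial * patch_size[1]    # 16
assert 480 % base == 0 and 832 % base == 0         # 480x832 passes the check
assert 830 % base != 0                             # 480x830 would raise the ValueError above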
@@ -414,7 +424,11 @@ def preprocess_conditions(
device: Optional[torch.device] = None,
):
if video is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
base = self.vae_scale_factor_spatial * (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
video_height, video_width = self.video_processor.get_default_height_width(video[0])

if video_height * video_width > height * width:
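The `transformer if ... else transformer_2` fallback above recurs throughout this diff; one way to express it once would be a small helper (a sketch only, the method name `_reference_transformer` is hypothetical and not part of the change):

def _reference_transformer(self):
    # check_inputs already guarantees that at least one transformer is set.
    return self.transformer if self.transformer is not None else self.transformer_2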
@@ -589,7 +603,11 @@ def prepare_masks(
"Generating with more than one video is not yet supported. This may be supported in the future."
)

transformer_patch_size = self.transformer.config.patch_size[1]
transformer_patch_size = (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)

mask_list = []
for mask_, reference_images_batch in zip(mask, reference_images):
@@ -844,20 +862,25 @@ def __call__(
batch_size = prompt_embeds.shape[0]

vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype

vace_layers = (
self.transformer.config.vace_layers
if self.transformer is not None
else self.transformer_2.config.vace_layers
)
if isinstance(conditioning_scale, (int, float)):
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
conditioning_scale = [conditioning_scale] * len(vace_layers)
if isinstance(conditioning_scale, list):
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
if len(conditioning_scale) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = torch.tensor(conditioning_scale)
if isinstance(conditioning_scale, torch.Tensor):
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
if conditioning_scale.size(0) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)

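For reference, the three forms of `conditioning_scale` that this block normalizes, assuming a model with two VACE layers (the layer indices are an assumption for illustration):

import torch

vace_layers = [0, 1]                             # hypothetical config value with two VACE layers
conditioning_scale = 1.0                         # a scalar is broadcast to every VACE layer
conditioning_scale = [1.0, 0.5]                  # a list must have one value per VACE layer
conditioning_scale = torch.tensor([1.0, 0.5])    # a tensor of matching length also works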
@@ -900,7 +923,11 @@ def __call__(
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(transformer_dtype)

num_channels_latents = self.transformer.config.in_channels
num_channels_latents = (
self.transformer.config.in_channels
if self.transformer is not None
else self.transformer_2.config.in_channels
)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -968,7 +995,7 @@ def __call__(
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
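End to end, a hedged usage sketch of the single-transformer path exercised by the tests below. The checkpoint id is a placeholder, and overriding `transformer` and `boundary_ratio` through `from_pretrained` is an assumption based on `transformer` now being listed in `_optional_components`:

import torch
from diffusers import WanVACEPipeline

pipe = WanVACEPipeline.from_pretrained(
    "path/to/wan-vace-checkpoint",   # placeholder, not a real repo id
    transformer=None,                # low-noise-only setup: rely on transformer_2
    boundary_ratio=1.0,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

video = pipe(
    prompt="a cat surfing a wave",
    guidance_scale=5.0,     # used while `transformer` runs (high-noise stage)
    guidance_scale_2=4.0,   # used once `transformer_2` takes over, per the fix above
).frames[0]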
83 changes: 81 additions & 2 deletions tests/pipelines/wan/test_wan_vace.py
@@ -12,16 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
WanVACEPipeline,
WanVACETransformer3DModel,
)

from ...testing_utils import enable_full_determinism
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin

@@ -212,3 +219,75 @@ def test_float16_inference(self):
)
def test_save_load_float16(self):
pass

def test_inference_with_only_transformer(self):
components = self.get_dummy_components()
components["transformer_2"] = None
components["boundary_ratio"] = 0.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)

def test_inference_with_only_transformer_2(self):
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
components["transformer"] = None

# FlowMatchEulerDiscreteScheduler doesn't support running only the low-noise stage,
# because its starting timestep t == 1000 == boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)

components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)

def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = ["transformer", "image_encoder", "image_processor"]

components = self.get_dummy_components()
for component in optional_component:
components[component] = None

components["boundary_ratio"] = 1.0

pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]

with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)

for component in optional_component:
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."

inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]

max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"