diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 30d497892ff5..80c78b8a96d5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -366,6 +366,8 @@ [ "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", + "WanAutoBlocks", + "WanModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -999,6 +1001,8 @@ from .modular_pipelines import ( StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, + WanAutoBlocks, + WanModularPipeline, ) from .pipelines import ( AllegroPipeline, diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 960d14e6fa2a..5fa047257f09 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -107,6 +107,7 @@ def _register(cls): def _register_attention_processors_metadata(): from ..models.attention_processor import AttnProcessor2_0 from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor + from ..models.transformers.transformer_wan import WanAttnProcessor2_0 # AttnProcessor2_0 AttentionProcessorRegistry.register( @@ -124,6 +125,14 @@ def _register_attention_processors_metadata(): ), ) + # WanAttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=WanAttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0, + ), + ) + def _register_transformer_blocks_metadata(): from ..models.attention import BasicTransformerBlock @@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, * _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states +_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states # fmt: on diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 14e6c2f8881e..0ce02e987d09 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -91,10 +91,19 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func is torch.nn.functional.scaled_dot_product_attention: + query = kwargs.get("query", None) + key = kwargs.get("key", None) value = kwargs.get("value", None) - if value is None: - value = args[2] - return value + query = query if query is not None else args[0] + key = key if key is not None else args[1] + value = value if value is not None else args[2] + # If the Q sequence length does not match KV sequence length, methods like + # Perturbed Attention Guidance cannot be used (because the caller expects + # the same sequence length as Q, but if we return V here, it will not match). + # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and + # the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale). 
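+            # Shape intuition: F.scaled_dot_product_attention returns a tensor with Q's sequence length,
+            # i.e. (batch, heads, q_len, head_dim), while V carries the KV sequence length (batch, heads, kv_len, head_dim).
+            # Returning V as the "skipped" attention output is therefore only shape-compatible when
+            # q_len == kv_len (self-attention); for cross-attention the caller's downstream reshapes would break.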
+ if query.shape[2] == value.shape[2]: + return value return func(*args, **kwargs) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index bf34eed28b8c..b3025bf4d3ab 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -40,6 +40,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -71,6 +72,7 @@ StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, ) + from .wan import WanAutoBlocks, WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 6f1c617bc261..8838a1cb5942 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -60,12 +60,14 @@ MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ("wan", "WanModularPipeline"), ] ) MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( [ ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), + ("WanModularPipeline", "WanAutoBlocks"), ] ) diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py new file mode 100644 index 000000000000..7b548e003c63 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -0,0 +1,66 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["WanTextEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "TEXT2VIDEO_BLOCKS", + "WanAutoBeforeDenoiseStep", + "WanAutoBlocks", + "WanAutoBlocks", + "WanAutoDecodeStep", + "WanAutoDenoiseStep", + ] + _import_structure["modular_pipeline"] = ["WanModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .encoders import WanTextEncoderStep + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + TEXT2VIDEO_BLOCKS, + WanAutoBeforeDenoiseStep, + WanAutoBlocks, + WanAutoDecodeStep, + WanAutoDenoiseStep, + ) + from .modular_pipeline import WanModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py new file mode 100644 index 000000000000..ef65b6453725 --- 
/dev/null +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -0,0 +1,365 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import torch + +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class WanInputStep(PipelineBlock): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_videos_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_videos_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields + description="negative text embeddings used to guide the image generation", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." 
+ ) + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_videos_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class WanSetTimestepsStep(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareLatentsStep(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("num_frames", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. 
Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + @staticmethod + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // comp.vae_scale_factor_spatial, + int(width) // comp.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.num_frames = block_state.num_frames or components.default_num_frames + block_state.device = components._execution_device + block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_videos_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.num_frames, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py new file mode 100644 index 000000000000..4fadeed4b954 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -0,0 +1,105 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
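Aside (illustration, not part of the patch): a minimal sketch of the latent-shape arithmetic performed by `WanPrepareLatentsStep.prepare_latents` above, assuming the default Wan VAE scale factors (8x spatial, 4x temporal) and 16 latent channels; the concrete numbers are only examples.

# Latent shape sketch for WanPrepareLatentsStep.prepare_latents (assumed default scale factors).
vae_scale_factor_spatial = 8
vae_scale_factor_temporal = 4
batch_size, num_channels_latents = 1, 16
height, width, num_frames = 480, 832, 81

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
shape = (
    batch_size,
    num_channels_latents,
    num_latent_frames,
    height // vae_scale_factor_spatial,  # 60
    width // vae_scale_factor_spatial,   # 104
)
print(shape)  # (1, 16, 21, 60, 104)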
+ +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLWan +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanDecodeStep(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ) + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "videos", + type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]], + description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + if not block_state.output_type == "latent": + latents = block_state.latents + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + latents = latents.to(vae_dtype) + block_state.videos = components.vae.decode(latents, return_dict=False)[0] + else: + block_state.videos = block_state.latents + + block_state.videos = components.video_processor.postprocess_video( + block_state.videos, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py new file mode 100644 index 000000000000..76c5cda5f95e --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -0,0 +1,261 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
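Aside (illustration, not part of the patch): in `WanDecodeStep.__call__` above, dividing by `latents_std` (which is stored as `1.0 / std`) and adding `latents_mean` undoes the encoder-side `(x - mean) / std` scaling. A tiny self-contained check with made-up values, not the real Wan VAE statistics:

import torch

mean, std = 0.5, 2.0                      # illustrative values only
x = torch.randn(4)
normalized = (x - mean) / std             # encoder-side scaling
inv_std = 1.0 / torch.tensor(std)         # what the block stores in `latents_std`
recovered = normalized / inv_std + mean   # what the block computes before `vae.decode`
assert torch.allclose(recovered, x)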
+ +from typing import Any, List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import WanTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + PipelineBlock, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanLoopDenoiser(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("attention_kwargs"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds. " + "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields = { + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + } + transformer_dtype = components.transformer.dtype + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latents.to(transformer_dtype), + timestep=t.flatten(), + encoder_hidden_states=prompt_embeds, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + + +class WanLoopAfterDenoiser(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + **block_state.scheduler_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ComponentSpec("scheduler", UniPCMultistepScheduler), + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def loop_intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.num_warmup_steps = max(
+            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+        )
+
+        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+            for i, t in enumerate(block_state.timesteps):
+                components, block_state = self.loop_step(components, block_state, i=i, t=t)
+                if i == len(block_state.timesteps) - 1 or (
+                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+                ):
+                    progress_bar.update()
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
+class WanDenoiseStep(WanDenoiseLoopWrapper):
+    block_classes = [
+        WanLoopDenoiser,
+        WanLoopAfterDenoiser,
+    ]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents. \n"
+            "Its loop logic is defined in the `WanDenoiseLoopWrapper.__call__` method. \n"
+            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+            " - `WanLoopDenoiser`\n"
+            " - `WanLoopAfterDenoiser`\n"
+            "This block supports the text2vid task."
+        )
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
new file mode 100644
index 000000000000..b2ecfd1aa61a
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -0,0 +1,242 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
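Aside (illustration, not part of the patch): with the default `ClassifierFreeGuidance(guidance_scale=5.0)` guider, the `components.guider(guider_state)` call in `WanLoopDenoiser` above conceptually reduces to the standard CFG combination of the conditional and unconditional noise predictions. The sketch below uses made-up tensors and is not the guider's actual API:

import torch

guidance_scale = 5.0
noise_pred_cond = torch.randn(1, 16, 21, 60, 104)    # prediction conditioned on prompt_embeds
noise_pred_uncond = torch.randn(1, 16, 21, 60, 104)  # prediction conditioned on negative_prompt_embeds
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)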
+ +import html +from typing import List, Optional, Union + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class WanTextEncoderStep(PipelineBlock): + model_name = "wan" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", UMT5EncoderModel), + ComponentSpec("tokenizer", AutoTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def _get_t5_prompt_embeds( + components, + prompt: Union[str, List[str]], + max_sequence_length: int, + device: torch.device, + ): + dtype = components.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = components.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + return prompt_embeds + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: Optional[torch.device] = None, + 
num_videos_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. + """ + device = device or components._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device) + + if prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds( + components, negative_prompt, max_sequence_length, device + ) + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py new file mode 100644 index 000000000000..5f4c1a983566 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -0,0 +1,144 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
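Aside (illustration, not part of the patch): the `repeat(1, num_videos_per_prompt, 1)` + `view(...)` pattern used in both `WanInputStep` and `encode_prompt` above tiles each prompt's embedding along the batch dimension. A small shape check with toy sizes:

import torch

batch_size, seq_len, dim, num_videos_per_prompt = 2, 5, 8, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)
out = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1)
assert out.shape == (batch_size * num_videos_per_prompt, seq_len, dim)
# rows 0-2 are copies of prompt 0, rows 3-5 are copies of prompt 1
assert torch.equal(out[0], prompt_embeds[0]) and torch.equal(out[3], prompt_embeds[1])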
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+    WanInputStep,
+    WanPrepareLatentsStep,
+    WanSetTimestepsStep,
+)
+from .decoders import WanDecodeStep
+from .denoise import WanDenoiseStep
+from .encoders import WanTextEncoderStep
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+# before_denoise: text2vid
+class WanBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = [
+        WanInputStep,
+        WanSetTimestepsStep,
+        WanPrepareLatentsStep,
+    ]
+    block_names = ["input", "set_timesteps", "prepare_latents"]
+
+    @property
+    def description(self):
+        return (
+            "Before denoise step that prepares the inputs for the denoise step.\n"
+            + "This is a sequential pipeline block:\n"
+            + " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+        )
+
+
+# before_denoise: all tasks (text2vid)
+class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
+    block_classes = [
+        WanBeforeDenoiseStep,
+    ]
+    block_names = ["text2vid"]
+    block_trigger_inputs = [None]
+
+    @property
+    def description(self):
+        return (
+            "Before denoise step that prepares the inputs for the denoise step.\n"
+            + "This is an auto pipeline block that works for text2vid.\n"
+            + " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
+        )
+
+
+# denoise: text2vid
+class WanAutoDenoiseStep(AutoPipelineBlocks):
+    block_classes = [
+        WanDenoiseStep,
+    ]
+    block_names = ["denoise"]
+    block_trigger_inputs = [None]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents. "
+            "This is an auto pipeline block that works for text2vid tasks.\n"
+            " - `WanDenoiseStep` (denoise) for text2vid tasks."
+        )
+
+
+# decode: all tasks (text2vid)
+class WanAutoDecodeStep(AutoPipelineBlocks):
+    block_classes = [WanDecodeStep]
+    block_names = ["non-inpaint"]
+    block_trigger_inputs = [None]
+
+    @property
+    def description(self):
+        return "Decode step that decodes the denoised latents into video outputs.\n - `WanDecodeStep`"
+
+
+# text2vid
+class WanAutoBlocks(SequentialPipelineBlocks):
+    block_classes = [
+        WanTextEncoderStep,
+        WanAutoBeforeDenoiseStep,
+        WanAutoDenoiseStep,
+        WanAutoDecodeStep,
+    ]
+    block_names = [
+        "text_encoder",
+        "before_denoise",
+        "denoise",
+        "decoder",
+    ]
+
+    @property
+    def description(self):
+        return (
+            "Auto Modular pipeline for text-to-video using Wan.\n"
+            + "- for text-to-video generation, all you need to provide is `prompt`"
+        )
+
+
+TEXT2VIDEO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", WanTextEncoderStep),
+        ("input", WanInputStep),
+        ("set_timesteps", WanSetTimestepsStep),
+        ("prepare_latents", WanPrepareLatentsStep),
+        ("denoise", WanDenoiseStep),
+        ("decode", WanDecodeStep),
+    ]
+)
+
+
+AUTO_BLOCKS = InsertableDict(
+    [
+        ("text_encoder", WanTextEncoderStep),
+        ("before_denoise", WanAutoBeforeDenoiseStep),
+        ("denoise", WanAutoDenoiseStep),
+        ("decode", WanAutoDecodeStep),
+    ]
+)
+
+
+ALL_BLOCKS = {
+    "text2video": TEXT2VIDEO_BLOCKS,
+    "auto": AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
new file mode 100644
index 000000000000..4d86e0d08e59
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -0,0 +1,90 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...loaders import WanLoraLoaderMixin
+from ...pipelines.pipeline_utils import StableDiffusionMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class WanModularPipeline(
+    ModularPipeline,
+    StableDiffusionMixin,
+    WanLoraLoaderMixin,
+):
+    """
+    A ModularPipeline for Wan.
+
+
+
+    This is an experimental feature and is likely to change in the future.
+ + + """ + + @property + def default_height(self): + return self.default_sample_height * self.vae_scale_factor_spatial + + @property + def default_width(self): + return self.default_sample_width * self.vae_scale_factor_spatial + + @property + def default_num_frames(self): + return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1 + + @property + def default_sample_height(self): + return 60 + + @property + def default_sample_width(self): + return 104 + + @property + def default_sample_num_frames(self): + return 21 + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def vae_scale_factor_temporal(self): + vae_scale_factor = 4 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** sum(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_transformer(self): + num_channels_transformer = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_transformer = self.transformer.config.in_channels + return num_channels_transformer + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.z_dim + return num_channels_latents diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index dde5bbda60bf..7538635c808e 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -32,6 +32,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class WanAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WanModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AllegroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]