4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -366,6 +366,8 @@
[
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"WanAutoBlocks",
"WanModularPipeline",
]
)
_import_structure["pipelines"].extend(
@@ -999,6 +1001,8 @@
from .modular_pipelines import (
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
WanAutoBlocks,
WanModularPipeline,
)
from .pipelines import (
AllegroPipeline,
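After this hunk, the Wan modular classes are exported from the package root. A quick smoke test of the new imports (assumes an install that includes this PR, with torch and transformers available):

```python
# Hedged check of the new top-level exports added above.
from diffusers import WanAutoBlocks, WanModularPipeline
```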
10 changes: 10 additions & 0 deletions src/diffusers/hooks/_helpers.py
@@ -107,6 +107,7 @@ def _register(cls):
def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_wan import WanAttnProcessor2_0

# AttnProcessor2_0
AttentionProcessorRegistry.register(
@@ -124,6 +125,14 @@ def _register_attention_processors_metadata():
),
)

# WanAttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=WanAttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
),
)


def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):

_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on
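For context: the `skip_processor_output_fn` registered above is what lets hooks such as layer skipping bypass a Wan attention layer cleanly. Since `WanAttnProcessor2_0.__call__` returns only `hidden_states`, skipping reduces to an identity on that input. A standalone sketch of the idea (the real helper lives in this file; this illustrative version and its name are assumptions):

```python
# Hedged sketch: a "skip" output function for a processor whose __call__
# returns only hidden_states (the WanAttnProcessor2_0 case). Skipping means
# handing back the hidden_states input untouched, i.e. attention as a no-op.
def skip_attention_return_hidden_states(self, *args, **kwargs):
    # Bound as a method on the attention module by the hook machinery.
    hidden_states = kwargs.get("hidden_states", args[0] if args else None)
    return hidden_states
```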
15 changes: 12 additions & 3 deletions src/diffusers/hooks/layer_skip.py
@@ -91,10 +91,19 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
         if func is torch.nn.functional.scaled_dot_product_attention:
+            query = kwargs.get("query", None)
+            key = kwargs.get("key", None)
             value = kwargs.get("value", None)
-            if value is None:
-                value = args[2]
-            return value
+            query = query if query is not None else args[0]
+            key = key if key is not None else args[1]
+            value = value if value is not None else args[2]
+            # If the Q sequence length does not match the KV sequence length, methods like
+            # Perturbed Attention Guidance cannot be used (the caller expects an output with
+            # the same sequence length as Q, so returning V here would not match).
+            # When Q.shape[2] != V.shape[2], PAG is effectively not applied, and the overall
+            # effect is that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+            if query.shape[2] == value.shape[2]:
+                return value
         return func(*args, **kwargs)
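The new guard matters for cross-attention, where V carries the context (text) sequence length while the caller expects an output with Q's sequence length. A toy illustration with made-up shapes:

```python
import torch

# Illustrative shapes only: (batch, heads, seq_len, head_dim)
q_self = torch.randn(1, 8, 128, 64)
v_self = torch.randn(1, 8, 128, 64)
# Self-attention: Q and V share a sequence length, so V can stand in for the skipped output.
assert q_self.shape[2] == v_self.shape[2]

q_cross = torch.randn(1, 8, 128, 64)
v_cross = torch.randn(1, 8, 77, 64)  # e.g. 77 text tokens
# Cross-attention: returning V would not match Q's sequence length, so SDPA runs normally.
assert q_cross.shape[2] != v_cross.shape[2]
```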


2 changes: 2 additions & 0 deletions src/diffusers/modular_pipelines/__init__.py
@@ -40,6 +40,7 @@
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["components_manager"] = ["ComponentsManager"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -71,6 +72,7 @@
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
)
from .wan import WanAutoBlocks, WanModularPipeline
else:
import sys

2 changes: 2 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline.py
@@ -60,12 +60,14 @@
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
]
)

MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
[
("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
("WanModularPipeline", "WanAutoBlocks"),
]
)
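These two tables are what let the modular machinery resolve the "wan" family string to its pipeline class and default block collection. A hedged sketch of the lookup (only the mapping entries come from this diff; the resolution flow shown is an assumption):

```python
from collections import OrderedDict

MODULAR_PIPELINE_MAPPING = OrderedDict(
    [
        ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
        ("wan", "WanModularPipeline"),
    ]
)
MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
    [
        ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
        ("WanModularPipeline", "WanAutoBlocks"),
    ]
)

pipeline_cls_name = MODULAR_PIPELINE_MAPPING["wan"]  # "WanModularPipeline"
blocks_cls_name = MODULAR_PIPELINE_BLOCKS_MAPPING[pipeline_cls_name]  # "WanAutoBlocks"
```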

66 changes: 66 additions & 0 deletions src/diffusers/modular_pipelines/wan/__init__.py
@@ -0,0 +1,66 @@
from typing import TYPE_CHECKING

from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"WanAutoBlocks",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
)
from .modular_pipeline import WanModularPipeline
else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)

for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
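With the lazy module in place, the subpackage exposes the Wan blocks and pipeline. A minimal end-to-end sketch, assuming the Wan flow mirrors the SDXL modular one (the `init_pipeline` call and the checkpoint id are assumptions, not code from this PR):

```python
# Hedged usage sketch; component loading and dtype handling omitted.
from diffusers.modular_pipelines import WanAutoBlocks

blocks = WanAutoBlocks()
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # assumed repo id
```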