Commit bcc02e7

committed: update

1 parent 4018ed8 commit bcc02e7
File tree

2 files changed: +22 -3 lines changed

src/diffusers/hooks/_helpers.py

Lines changed: 10 additions & 0 deletions
@@ -107,6 +107,7 @@ def _register(cls):
 def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+    from ..models.transformers.transformer_wan import WanAttnProcessor2_0
 
     # AttnProcessor2_0
     AttentionProcessorRegistry.register(
@@ -124,6 +125,14 @@ def _register_attention_processors_metadata():
         ),
     )
 
+    # WanAttnProcessor2_0
+    AttentionProcessorRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     from ..models.attention import BasicTransformerBlock
@@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 
 _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
 # fmt: on
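
For context: the registry maps each attention processor class to metadata whose skip_processor_output_fn mimics that processor's return signature, so a hook can bypass attention while still returning outputs of the shape the caller expects. WanAttnProcessor2_0 returns a single tensor, so it is aliased to the one-output helper. A minimal sketch of what that helper presumably does, inferred from its naming convention (the body here is an assumption, not the actual implementation):

# Sketch only: body inferred from the helper's name, not copied from the repo.
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    # Stand in for the processor's forward: return the incoming
    # hidden_states untouched, i.e. treat attention as identity.
    hidden_states = kwargs.get("hidden_states", args[0] if args else None)
    return hidden_states

Registering WanAttnProcessor2_0 with this metadata is what lets the layer-skip hooks treat Wan's attention blocks the same way they already treat AttnProcessor2_0.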

src/diffusers/hooks/layer_skip.py

Lines changed: 12 additions & 3 deletions
@@ -91,10 +91,19 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
         if func is torch.nn.functional.scaled_dot_product_attention:
+            query = kwargs.get("query", None)
+            key = kwargs.get("key", None)
             value = kwargs.get("value", None)
-            if value is None:
-                value = args[2]
-            return value
+            query = query if query is not None else args[0]
+            key = key if key is not None else args[1]
+            value = value if value is not None else args[2]
+            # If the Q sequence length does not match the KV sequence length, methods like
+            # Perturbed Attention Guidance cannot be used (the caller expects an output with
+            # the same sequence length as Q, but if we returned V here, it would not match).
+            # When Q.shape[2] != V.shape[2], PAG is essentially not applied, and the overall
+            # effect is that of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+            if query.shape[2] == value.shape[2]:
+                return value
         return func(*args, **kwargs)