
Commit aaf2df8: "update"

1 parent 7e97e43 commit aaf2df8

File tree: 2 files changed, +180 -57 lines


src/diffusers/models/transformers/transformer_wan.py

Lines changed: 167 additions & 38 deletions
@@ -21,10 +21,9 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -35,36 +34,61 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class WanAttnProcessor2_0:
+class WanAttnProcessor:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
+
+    def get_qkv_projections(
+        self, attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+    ):
+        # encoder_hidden_states is only passed for cross-attention
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        if attn.fused_projections:
+            if attn.cross_attention_dim_head is None:
+                # In self-attention layers, we can fuse the entire QKV projection into a single linear
+                query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+            else:
+                # In cross-attention layers, we can only fuse the KV projections into a single linear
+                query = attn.to_q(hidden_states)
+                key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+        else:
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+        return query, key, value
+
+    def get_added_kv_projections(self, attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+        if attn.fused_projections:
+            key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+        else:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+        return key_img, value_img
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
        hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
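Note (editor's sketch, not part of the commit): the fused path in `get_qkv_projections` works because stacking the Q/K/V weight matrices row-wise into one linear layer and chunking its output reproduces the three separate projections. A minimal, self-contained PyTorch illustration of that identity:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 16
x = torch.randn(2, 4, dim)

to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Build the fused layer the same way WanAttention.fuse_projections does:
# row-wise concatenation of the weights and biases.
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
to_qkv.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

# Chunking the fused output reproduces the three separate projections.
q, k, v = to_qkv(x).chunk(3, dim=-1)
torch.testing.assert_close(q, to_q(x))
torch.testing.assert_close(k, to_k(x))
torch.testing.assert_close(v, to_v(x))
```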
@@ -92,9 +116,8 @@ def apply_rotary_emb(
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = self.get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
 
             key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
             value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
@@ -119,6 +142,119 @@ def apply_rotary_emb(
         return hidden_states
 
 
+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
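A hypothetical usage sketch for the new module (editor's illustration, not part of the commit). It assumes WanAttention is importable from the path introduced here and that AttentionModuleMixin initializes fused_projections to False; the shapes are made up:

```python
import torch

from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

attn = WanAttention(dim=64, heads=4, dim_head=16, processor=WanAttnProcessor())
x = torch.randn(1, 8, 64)

with torch.no_grad():
    out_separate = attn(x, None)   # self-attention through to_q / to_k / to_v
    attn.fuse_projections()        # concatenates the three weight matrices into attn.to_qkv
    out_fused = attn(x, None)      # same result through a single larger matmul
    attn.unfuse_projections()      # drops to_qkv and falls back to the separate projections

torch.testing.assert_close(out_separate, out_fused)
```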
@@ -266,33 +402,24 @@ def __init__(
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )
 
         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -315,12 +442,12 @@ def forward(
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
@@ -333,7 +460,9 @@ def forward(
         return hidden_states
 
 
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.
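With AttentionMixin added to WanTransformer3DModel, fusion can be toggled for the whole model rather than per module. A sketch of the intended usage (editor's illustration), assuming AttentionMixin exposes fuse_qkv_projections() / unfuse_qkv_projections() that dispatch to each WanAttention's fuse_projections() / unfuse_projections(); the checkpoint id is only an example:

```python
import torch

from diffusers import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()    # every self-attention gains to_qkv, every cross-attention gains to_kv
# ... run inference ...
transformer.unfuse_qkv_projections()  # restore the separate to_q / to_k / to_v linears
```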

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 13 additions & 19 deletions
@@ -22,12 +22,17 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
-from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
-from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+from .transformer_wan import (
+    WanAttention,
+    WanAttnProcessor,
+    WanRotaryPosEmbed,
+    WanTimeTextImageEmbedding,
+    WanTransformerBlock,
+)
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
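The constructor calls in the hunk below drop several keyword arguments the generic Attention class required (bias=True, out_bias=True, qk_norm, added_proj_bias=True). A quick sanity sketch, under the same import assumptions as above, of why the slimmer calls are equivalent: those options are now fixed behaviors of WanAttention (values here are made up):

```python
import torch

from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

attn = WanAttention(dim=64, heads=4, dim_head=16, added_kv_proj_dim=64, processor=WanAttnProcessor())
assert attn.to_q.bias is not None and attn.to_out[0].bias is not None  # bias / out_bias are always on
assert isinstance(attn.norm_q, torch.nn.RMSNorm)                       # qk_norm is always applied (RMSNorm)
assert attn.add_k_proj.bias is not None                                # added_proj_bias is always on
```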
@@ -55,33 +60,22 @@ def __init__(
 
         # 2. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=WanAttnProcessor(),
         )
 
         # 3. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
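Finally, a backwards-compatibility note on the renamed processor: the WanAttnProcessor2_0 shim kept in transformer_wan.py overrides __new__, emits a deprecation warning through diffusers' deprecate helper, and hands back a WanAttnProcessor instance, so existing code that still instantiates the old class keeps working. Illustrative sketch:

```python
from diffusers.models.transformers.transformer_wan import WanAttnProcessor, WanAttnProcessor2_0

processor = WanAttnProcessor2_0()  # warns that the class is deprecated (removal slated for 1.0.0)
assert isinstance(processor, WanAttnProcessor)
```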