
Commit 11d22e0

Authored by samuelt0, github-actions[bot], and a-r-r-o-w
Cross attention module to Wan Attention (#12058)
* Cross attention module to Wan Attention
* Apply style fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <[email protected]>
1 parent 9a38fab

File tree

1 file changed: +3, -0 lines changed

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 3 additions & 0 deletions
@@ -180,6 +180,7 @@ def __init__(
         added_kv_proj_dim: Optional[int] = None,
         cross_attention_dim_head: Optional[int] = None,
         processor=None,
+        is_cross_attention=None,
     ):
         super().__init__()
 
@@ -207,6 +208,8 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
+        self.is_cross_attention = cross_attention_dim_head is not None
+
         self.set_processor(processor)
 
     def fuse_projections(self):
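Note that the diff derives `self.is_cross_attention` from whether `cross_attention_dim_head` was configured, rather than from the new constructor argument. Below is a minimal sketch of how an attention module can branch on such a flag; the `TinyAttention` class, dimensions, and tensor shapes are illustrative assumptions, not the diffusers implementation:

```python
import torch


class TinyAttention(torch.nn.Module):
    """Illustrative module that routes keys/values based on an is_cross_attention flag."""

    def __init__(self, dim: int, cross_attention_dim_head=None):
        super().__init__()
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        # Mirrors the commit: the flag is derived from the configured
        # cross-attention head dim, not passed in directly.
        self.is_cross_attention = cross_attention_dim_head is not None

    def forward(self, hidden_states, encoder_hidden_states=None):
        # Cross-attention draws keys/values from the encoder sequence;
        # self-attention draws them from the input itself.
        context = encoder_hidden_states if self.is_cross_attention else hidden_states
        q = self.to_q(hidden_states)
        k = self.to_k(context)
        v = self.to_v(context)
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


attn = TinyAttention(64, cross_attention_dim_head=8)
x = torch.randn(2, 16, 64)    # (batch, query tokens, dim)
ctx = torch.randn(2, 77, 64)  # (batch, context tokens, dim)
out = attn(x, ctx)            # -> (2, 16, 64)
```

Exposing the flag as a stored attribute lets attention processors check `attn.is_cross_attention` instead of re-deriving the self- vs. cross-attention distinction from the presence of `encoder_hidden_states` at call time.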
