3 changes: 3 additions & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -180,6 +180,7 @@ def __init__(
         added_kv_proj_dim: Optional[int] = None,
         cross_attention_dim_head: Optional[int] = None,
         processor=None,
+        is_cross_attention=None
     ):
         super().__init__()
 
@@ -207,6 +208,8 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
 
+        self.is_cross_attention = cross_attention_dim_head is not None
+
         self.set_processor(processor)
 
     def fuse_projections(self):
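For context: the net effect of this change is that the attention module records whether it acts as a cross-attention block by checking whether `cross_attention_dim_head` was supplied. Note that the new `is_cross_attention` constructor argument is accepted but the attribute is derived from `cross_attention_dim_head`, not from the argument itself. Below is a minimal, self-contained sketch of that pattern; the class name, parameter defaults, and usage are illustrative stand-ins, not the actual diffusers `WanAttention` signature.

```python
from typing import Optional

import torch


class AttentionSketch(torch.nn.Module):
    """Illustrative stand-in for the modified attention module."""

    def __init__(
        self,
        dim_head: int = 64,
        heads: int = 8,
        cross_attention_dim_head: Optional[int] = None,
        processor=None,
        is_cross_attention=None,  # accepted as in the diff, but unused below
    ):
        super().__init__()
        # Mirrors the added line: the flag is derived from whether a
        # cross-attention head dimension was configured, not from the
        # is_cross_attention argument.
        self.is_cross_attention = cross_attention_dim_head is not None


self_attn = AttentionSketch()                              # no cross-attn dim: flag is False
cross_attn = AttentionSketch(cross_attention_dim_head=64)  # cross-attn dim set: flag is True
assert not self_attn.is_cross_attention and cross_attn.is_cross_attention
```

One consequence of deriving the flag this way is that callers cannot override it by passing `is_cross_attention=True`; the attribute always tracks the `cross_attention_dim_head` configuration.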