diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 2b6d5953fc4f..968a0369c243 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -180,6 +180,7 @@ def __init__( added_kv_proj_dim: Optional[int] = None, cross_attention_dim_head: Optional[int] = None, processor=None, + is_cross_attention=None, ): super().__init__() @@ -207,6 +208,8 @@ def __init__( self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + self.is_cross_attention = cross_attention_dim_head is not None + self.set_processor(processor) def fuse_projections(self):