Commit 5b8927a

remove transpose; fix rope shape
1 parent 4940b21 commit 5b8927a

1 file changed: +10 -11 lines changed

src/diffusers/models/transformers/transformer_wan.py

@@ -93,9 +93,9 @@ def __call__(
         query = attn.norm_q(query)
         key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         if rotary_emb is not None:
 
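Dropping the .transpose(1, 2) leaves query/key/value in [batch, seq_len, heads, head_dim] layout rather than the previous [batch, heads, seq_len, head_dim]. A minimal shape sketch of the before/after (standalone PyTorch with illustrative sizes, not part of the diff):

    import torch

    batch, seq_len, heads, head_dim = 2, 16, 8, 64  # illustrative sizes
    query = torch.randn(batch, seq_len, heads * head_dim)

    # Before this commit: split heads, then move them ahead of the sequence axis.
    q_old = query.unflatten(2, (heads, -1)).transpose(1, 2)  # [batch, heads, seq_len, head_dim]
    # After this commit: split heads and keep the sequence axis in place.
    q_new = query.unflatten(2, (heads, -1))                  # [batch, seq_len, heads, head_dim]

    assert q_old.shape == (batch, heads, seq_len, head_dim)
    assert q_new.shape == (batch, seq_len, heads, head_dim)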

@@ -104,8 +104,7 @@ def apply_rotary_emb(
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
-                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
-                x1, x2 = x[..., 0], x[..., 1]
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
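
Both formulations split the last dimension into interleaved (even, odd) pairs; the new one-liner just expresses it with unflatten + unbind. A quick equivalence check (standalone, illustrative sizes, not part of the diff):

    import torch

    hidden_states = torch.randn(2, 16, 8, 64)

    # Old: reshape the last dim into pairs, then index the two halves.
    x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
    x1_old, x2_old = x[..., 0], x[..., 1]

    # New: the same split expressed with unflatten + unbind.
    x1_new, x2_new = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)

    assert torch.equal(x1_old, x1_new) and torch.equal(x2_old, x2_new)
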
@@ -122,8 +121,8 @@ def apply_rotary_emb(
        key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
        key_img = attn.norm_added_k(key_img)
 
-        key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key_img = key_img.unflatten(2, (attn.heads, -1))
+        value_img = value_img.unflatten(2, (attn.heads, -1))
 
        hidden_states_img = dispatch_attention_fn(
            query,
@@ -134,7 +133,7 @@ def apply_rotary_emb(
            is_causal=False,
            backend=self._attention_backend,
        )
-        hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+        hidden_states_img = hidden_states_img.flatten(2, 3)
        hidden_states_img = hidden_states_img.type_as(query)
 
        hidden_states = dispatch_attention_fn(
@@ -146,7 +145,7 @@ def apply_rotary_emb(
            is_causal=False,
            backend=self._attention_backend,
        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
 
        if hidden_states_img is not None:
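
With the [batch, seq_len, heads, head_dim] layout, merging heads back into [batch, seq_len, heads * head_dim] is a single flatten(2, 3); the old layout needed a transpose first. A sketch of the equivalence (illustrative sizes; assumes, as this refactor implies, that dispatch_attention_fn returns output in the same layout it was given):

    import torch

    batch, seq_len, heads, head_dim = 2, 16, 8, 64  # illustrative sizes
    attn_out = torch.randn(batch, seq_len, heads, head_dim)

    # Old path: output in [batch, heads, seq_len, head_dim], transpose then flatten.
    old_out = attn_out.transpose(1, 2)                  # simulate the old layout
    merged_old = old_out.transpose(1, 2).flatten(2, 3)  # [batch, seq_len, heads * head_dim]

    # New path: heads already sit next to head_dim, so flatten directly.
    merged_new = attn_out.flatten(2, 3)                 # [batch, seq_len, heads * head_dim]

    assert torch.equal(merged_old, merged_new)
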
@@ -395,8 +394,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
-        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
 
        return freqs_cos, freqs_sin
 
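This is the "fix rope shape" half of the commit message: with query/key now in [batch, seq_len, heads, head_dim], the cos/sin tables must broadcast over the heads axis at dim 2, so they are reshaped to (1, seq_len, 1, head_dim) instead of (1, 1, seq_len, head_dim). A broadcasting sketch (standalone, illustrative sizes):

    import torch

    batch, seq_len, heads, head_dim = 2, 16, 8, 64  # illustrative sizes
    query = torch.randn(batch, seq_len, heads, head_dim)

    # Old shape broadcast against [batch, heads, seq_len, head_dim].
    freqs_old = torch.randn(1, 1, seq_len, head_dim)
    # New shape broadcasts against [batch, seq_len, heads, head_dim].
    freqs_new = freqs_old.reshape(1, seq_len, 1, head_dim)

    out = query * freqs_new  # broadcasts over batch (dim 0) and heads (dim 2)
    assert out.shape == (batch, seq_len, heads, head_dim)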
