Commit c539784

refactor rope
1 parent 3423a98 commit c539784

File tree

3 files changed: +104 -173 lines changed

scripts/convert_wan_to_diffusers.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,8 @@
     "head.modulation": "scale_shift_table",
     "head.head": "proj_out",
     "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
     # Hack to swap the layer names
     # The original model calls the norms in following order: norm1, norm3, norm2
     # We convert it to: norm1, norm2, norm3
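The two new entries map the original checkpoint's ffn.0 / ffn.2 linears onto the FeedForward layout used after the refactor (ffn.net.0.proj / ffn.net.2). As a hypothetical illustration (not the script's actual code), such a rename map is typically applied by substring replacement over every checkpoint key:

# Hypothetical helper, not taken from convert_wan_to_diffusers.py.
RENAME_DICT = {
    "ffn.0": "ffn.net.0.proj",
    "ffn.2": "ffn.net.2",
}

def rename_key(key: str) -> str:
    # Replace every occurrence of an old sub-name with its diffusers counterpart.
    for old, new in RENAME_DICT.items():
        key = key.replace(old, new)
    return key

print(rename_key("blocks.0.ffn.0.weight"))  # blocks.0.ffn.net.0.proj.weight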

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 96 additions & 161 deletions
@@ -23,7 +23,7 @@
 from ...utils import logging
 from ..attention import FeedForward
 from ..attention_processor import Attention
-from ..embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
+from ..embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -45,14 +45,8 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        grid_sizes: Optional[torch.Tensor] = None,
-        freqs: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, _, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        # i2v task
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             encoder_hidden_states_img = encoder_hidden_states[:, :257]
@@ -69,19 +63,20 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if grid_sizes is not None and freqs is not None:
-            query = apply_rotary_emb(query, grid_sizes, freqs)
-            key = apply_rotary_emb(key, grid_sizes, freqs)
-
-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
 
-        # i2v task
+        # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
             key_img = attn.add_k_proj(encoder_hidden_states_img)
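The refactored processor applies RoPE inline in the complex domain: adjacent channel pairs are viewed as complex numbers and multiplied by unit-magnitude frequencies. A minimal standalone sketch of that operation, using toy sizes and random angles in place of the real position-dependent frequencies:

import torch

batch, heads, seq_len, head_dim = 1, 2, 4, 8  # toy sizes for illustration only

query = torch.randn(batch, heads, seq_len, head_dim)
# Unit-magnitude complex frequencies, broadcastable over batch and heads.
angles = torch.randn(1, 1, seq_len, head_dim // 2, dtype=torch.float64)
freqs = torch.polar(torch.ones_like(angles), angles)

def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # [B, H, S, D] -> [B, H, S, D/2] complex, rotate, then back to [B, H, S, D]
    x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
    x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
    return x_out.type_as(hidden_states)

rotated = apply_rotary_emb(query, freqs)
assert rotated.shape == query.shape  # RoPE only rotates pairs, the shape is unchanged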
@@ -111,45 +106,6 @@ def __call__(
         return hidden_states
 
 
-@torch.cuda.amp.autocast(enabled=False)
-def rope_params(max_seq_len, dim, theta=10000):
-    assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
-
-
-def apply_rotary_emb(hidden_states: torch.Tensor, grid_sizes, freqs):
-    n, c = hidden_states.size(2), hidden_states.size(3) // 2
-
-    # split freqs
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
-    # loop over samples
-    output = []
-    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
-        seq_len = f * h * w
-
-        # precompute multipliers
-        x_i = torch.view_as_complex(hidden_states[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
-        freqs_i = torch.cat([
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
-        ], dim=-1).reshape(seq_len, 1, -1)
-
-        # apply rotary embedding
-        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
-        x_i = torch.cat([x_i, hidden_states[i, seq_len:]])
-
-        # append to collection
-        output.append(x_i)
-    return torch.stack(output).type_as(hidden_states)
-
-
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int):
         super().__init__()
@@ -188,10 +144,8 @@ def __init__(
 
     def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None):
         timestep = self.timesteps_proj(timestep)
-        with torch.amp.autocast(str(encoder_hidden_states.device), dtype=torch.float32):
-            temb = self.time_embedder(timestep)
-            timestep_proj = self.time_proj(self.act_fn(temb))
-        assert temb.dtype == torch.float32 and timestep_proj.dtype == torch.float32
+        temb = self.time_embedder(timestep.type_as(encoder_hidden_states))
+        timestep_proj = self.time_proj(self.act_fn(temb))
 
         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
         if encoder_hidden_states_image is not None:
@@ -200,15 +154,49 @@ def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, e
         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 
+class WanRotaryPosEmbed(nn.Module):
+    def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0):
+        super().__init__()
+
+        self.attention_head_dim = attention_head_dim
+        self.patch_size = patch_size
+        self.max_seq_len = max_seq_len
+
+        h_dim = w_dim = 2 * (attention_head_dim // 6)
+        t_dim = attention_head_dim - h_dim - w_dim
+
+        freqs = []
+        for dim in [t_dim, h_dim, w_dim]:
+            freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64)
+            freqs.append(freq)
+        self.freqs = torch.cat(freqs, dim=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.patch_size
+        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+        self.freqs = self.freqs.to(hidden_states.device)
+        freqs = self.freqs.split_with_sizes(
+            [self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), self.attention_head_dim // 6, self.attention_head_dim // 6], dim=1
+        )
+
+        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        return freqs
+
+
 class WanTransformerBlock(nn.Module):
     def __init__(self,
-                 dim,
-                 ffn_dim,
-                 num_heads,
-                 qk_norm=True,
-                 cross_attn_norm=False,
-                 eps=1e-6,
-                 added_kv_proj_dim=None
+                 dim: int,
+                 ffn_dim: int,
+                 num_heads: int,
+                 qk_norm: str = "rms_norm_across_heads",
+                 cross_attn_norm: bool = False,
+                 eps: float = 1e-6,
+                 added_kv_proj_dim: Optional[int] = None
     ):
         super().__init__()
         self.dim = dim
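The new WanRotaryPosEmbed precomputes complex frequencies with get_1d_rotary_pos_embed, splits the head dimension across the temporal axis and the two spatial axes, and broadcasts each split over the post-patch grid. A standalone sketch of the shape bookkeeping, with toy sizes and a hand-rolled stand-in for get_1d_rotary_pos_embed:

import torch

attention_head_dim = 64
ppf, pph, ppw = 4, 6, 6  # post-patch frames / height / width (toy values)

h_dim = w_dim = 2 * (attention_head_dim // 6)   # 20
t_dim = attention_head_dim - h_dim - w_dim      # 24

def complex_freqs(dim: int, max_pos: int, theta: float = 10000.0) -> torch.Tensor:
    # 1D RoPE frequencies in complex form, equivalent in spirit to
    # get_1d_rotary_pos_embed(..., use_real=False).
    inv = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float64), inv)
    return torch.polar(torch.ones_like(angles), angles)  # [max_pos, dim // 2]

freqs_f = complex_freqs(t_dim, ppf).view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = complex_freqs(h_dim, pph).view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = complex_freqs(w_dim, ppw).view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

rotary_emb = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
print(rotary_emb.shape)  # torch.Size([1, 1, 144, 32]) == (1, 1, seq_len, attention_head_dim // 2)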
@@ -248,54 +236,37 @@ def __init__(self,
             added_proj_bias=True,
             processor=WanAttnProcessor2_0(),
         )
-
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
-
-        self.ffn = nn.Sequential(
-            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
-            nn.Linear(ffn_dim, dim)
-        )
+
+        # 3. Feed-forward
+        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
         self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
 
         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        temb: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        grid_sizes,
-        freqs,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        assert temb.dtype == torch.float32
-        with torch.amp.autocast(str(temb.device), dtype=torch.float32):
-            temb = (self.scale_shift_table + temb).chunk(6, dim=1)
-        assert temb[0].dtype == torch.float32
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb).chunk(6, dim=1)
 
         # 1. Self-attention
-        attn_hidden_states = (self.norm1(hidden_states.float()) * (1 + temb[1]) + temb[0]).type_as(hidden_states)
-
-        attn_hidden_states = self.attn1(
-            hidden_states=attn_hidden_states,
-            grid_sizes=grid_sizes,
-            freqs=freqs,
-        )
-        hidden_states = (hidden_states.float() + attn_hidden_states.float() * temb[2]).type_as(hidden_states)
+        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        hidden_states = hidden_states + attn_output * gate_msa
 
         # 2. Cross-attention
-        attn_hidden_states = self.norm2(hidden_states)
-        attn_hidden_states = self.attn2(
-            hidden_states=attn_hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            grid_sizes=None,
-            freqs=None,
-        )
-        hidden_states = hidden_states + attn_hidden_states
+        norm_hidden_states = self.norm2(hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
-        ffn_hidden_states = (self.norm3(hidden_states).float() * (1 + temb[4]) + temb[3]).type_as(hidden_states)
-        ffn_hidden_states = self.ffn(ffn_hidden_states)
-        hidden_states = (hidden_states.float() + ffn_hidden_states.float() * temb[5]).type_as(hidden_states)
+        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
+        ff_output = self.ffn(norm_hidden_states)
+        hidden_states = hidden_states + ff_output * c_gate_msa
 
         return hidden_states
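The block's forward now unpacks the six modulation tensors by name rather than indexing into a tuple, and gates every residual branch. A toy sketch of that scale/shift/gate pattern (made-up sizes; the attention call is stubbed out with an identity):

import torch

dim, batch, seq = 8, 2, 5
scale_shift_table = torch.randn(1, 6, dim) / dim**0.5
temb = torch.randn(batch, 6, dim)  # per-sample timestep modulation

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
    scale_shift_table + temb
).chunk(6, dim=1)

hidden_states = torch.randn(batch, seq, dim)
norm = torch.nn.LayerNorm(dim, elementwise_affine=False)

# Self-attention branch: modulate the normed input, then gate the residual update.
norm_hidden_states = norm(hidden_states) * (1 + scale_msa) + shift_msa
attn_output = norm_hidden_states  # stand-in for attn1(...)
hidden_states = hidden_states + attn_output * gate_msa
print(hidden_states.shape)  # torch.Size([2, 5, 8])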

@@ -338,7 +309,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _skip_layerwise_casting_patterns = ["patch_embedding", "text_embedding", "time_embedding", "time_projection", "norm"]
+    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock"]
 
     @register_to_config
@@ -358,16 +329,15 @@ def __init__(
         eps: float = 1e-6,
         image_embedding_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
+        rope_max_seq_len: int = 1024,
     ) -> None:
         super().__init__()
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
 
-        self.out_channels = out_channels
-        self.patch_size = patch_size
-
-        # 1. Patch embedding
+        # 1. Patch & position embedding
+        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
         self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
 
         # 2. Condition embeddings
@@ -391,14 +361,6 @@ def __init__(
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
-        # buffers (don't use register_buffer otherwise dtype will be changed in to())
-        assert attention_head_dim % 2 == 0
-        self.freqs = torch.cat([
-            rope_params(1024, attention_head_dim - 4 * (attention_head_dim // 6)),
-            rope_params(1024, 2 * (attention_head_dim // 6)),
-            rope_params(1024, 2 * (attention_head_dim // 6))
-        ], dim=1)
-
         self.gradient_checkpointing = False
 
     def forward(
@@ -409,14 +371,15 @@ def forward(
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if self.freqs.device != hidden_states.device:
-            self.freqs = self.freqs.to(hidden_states.device)
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.config.patch_size
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p_h
+        post_patch_width = width // p_w
+
+        rotary_emb = self.rope(hidden_states)
 
         hidden_states = self.patch_embedding(hidden_states)
-
-        grid_sizes = torch.stack(
-            [torch.tensor(u.shape[1:], dtype=torch.long) for u in hidden_states]
-        )
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
@@ -428,49 +391,21 @@ def forward(
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.blocks:
-                hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    timestep_proj,
-                    encoder_hidden_states,
-                    grid_sizes,
-                    self.freqs,
-                )
+                hidden_states = self._gradient_checkpointing_func(block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
         else:
             for block in self.blocks:
-                hidden_states = block(
-                    hidden_states,
-                    timestep_proj,
-                    encoder_hidden_states,
-                    grid_sizes,
-                    self.freqs,
-                )
-
-        # Output projection
-        with torch.amp.autocast(str(hidden_states.device), dtype=torch.float32):
-            temb = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states) * (1 + temb[1]) + temb[0]
-            hidden_states = self.proj_out(hidden_states)
-
-        hidden_states = hidden_states.type_as(encoder_hidden_states)
-
-        # 5. Unpatchify
-        hidden_states = self.unpatchify(hidden_states, grid_sizes)
-        hidden_states = torch.stack(hidden_states)
+                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
-        if not return_dict:
-            return (hidden_states,)
+        # 5. Output norm, projection & unpatchify
+        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
 
-        return Transformer2DModelOutput(sample=hidden_states)
+        hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1)
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
 
+        if not return_dict:
+            return (output,)
 
-    def unpatchify(self, hidden_states, grid_sizes):
-        c = self.out_channels
-        out = []
-        for u, v in zip(hidden_states, grid_sizes.tolist()):
-            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
-            u = torch.einsum('fhwpqrc->cfphqwr', u)
-            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-            out.append(u)
-        return out
-
+        return Transformer2DModelOutput(sample=output)
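The per-sample unpatchify loop built on torch.einsum is replaced by one batched reshape/permute/flatten. A sketch with toy sizes showing that the new sequence of ops recovers the expected (batch, channels, frames, height, width) layout:

import torch

batch_size, out_channels = 1, 16
p_t, p_h, p_w = 1, 2, 2
post_patch_num_frames, post_patch_height, post_patch_width = 3, 4, 5

# Token sequence as produced by proj_out: one token per patch, channels * patch volume features.
tokens = torch.randn(
    batch_size,
    post_patch_num_frames * post_patch_height * post_patch_width,
    out_channels * p_t * p_h * p_w,
)

# Split tokens back into grid axes and patch factors, interleave them, then merge.
hidden_states = tokens.reshape(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
print(output.shape)  # torch.Size([1, 16, 3, 8, 10]) == (B, C, F*p_t, H*p_h, W*p_w)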
