Skip to content

Commit a7e7b2f

Browse files
committed
Refactor rotary embedding calculations in SkyReelsV2 to separate cosine and sine frequencies
1 parent 6856ee6 commit a7e7b2f

File tree

1 file changed

+56
-27
lines changed

1 file changed

+56
-27
lines changed

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 56 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -108,13 +108,21 @@ def __call__(
108108

109109
if rotary_emb is not None:
110110

111-
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
112-
x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
113-
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
114-
return x_out.type_as(hidden_states)
115-
116-
query = apply_rotary_emb(query, rotary_emb)
117-
key = apply_rotary_emb(key, rotary_emb)
111+
def apply_rotary_emb(
112+
hidden_states: torch.Tensor,
113+
freqs_cos: torch.Tensor,
114+
freqs_sin: torch.Tensor,
115+
):
116+
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
117+
cos = freqs_cos[..., 0::2]
118+
sin = freqs_sin[..., 1::2]
119+
out = torch.empty_like(hidden_states)
120+
out[..., 0::2] = x1 * cos - x2 * sin
121+
out[..., 1::2] = x1 * sin + x2 * cos
122+
return out.type_as(hidden_states)
123+
124+
query = apply_rotary_emb(query, *rotary_emb)
125+
key = apply_rotary_emb(key, *rotary_emb)
118126

119127
# I2V task
120128
hidden_states_img = None
@@ -358,7 +366,11 @@ def forward(
358366

359367
class SkyReelsV2RotaryPosEmbed(nn.Module):
360368
def __init__(
361-
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
369+
self,
370+
attention_head_dim: int,
371+
patch_size: Tuple[int, int, int],
372+
max_seq_len: int,
373+
theta: float = 10000.0,
362374
):
363375
super().__init__()
364376

@@ -368,35 +380,52 @@ def __init__(
368380

369381
h_dim = w_dim = 2 * (attention_head_dim // 6)
370382
t_dim = attention_head_dim - h_dim - w_dim
383+
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
384+
385+
freqs_cos = []
386+
freqs_sin = []
371387

372-
freqs = []
373388
for dim in [t_dim, h_dim, w_dim]:
374-
freq = get_1d_rotary_pos_embed(
375-
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
389+
freq_cos, freq_sin = get_1d_rotary_pos_embed(
390+
dim,
391+
max_seq_len,
392+
theta,
393+
use_real=True,
394+
repeat_interleave_real=True,
395+
freqs_dtype=freqs_dtype,
376396
)
377-
freqs.append(freq)
378-
self.freqs = torch.cat(freqs, dim=1)
397+
freqs_cos.append(freq_cos)
398+
freqs_sin.append(freq_sin)
399+
400+
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
401+
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
379402

380403
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
381404
batch_size, num_channels, num_frames, height, width = hidden_states.shape
382405
p_t, p_h, p_w = self.patch_size
383406
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
384407

385-
freqs = self.freqs.to(hidden_states.device)
386-
freqs = freqs.split_with_sizes(
387-
[
388-
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
389-
self.attention_head_dim // 6,
390-
self.attention_head_dim // 6,
391-
],
392-
dim=1,
393-
)
408+
split_sizes = [
409+
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
410+
self.attention_head_dim // 3,
411+
self.attention_head_dim // 3,
412+
]
413+
414+
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
415+
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
416+
417+
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
418+
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
419+
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
420+
421+
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
422+
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
423+
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
424+
425+
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
426+
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
394427

395-
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
396-
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
397-
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
398-
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
399-
return freqs
428+
return freqs_cos, freqs_sin
400429

401430

402431
@maybe_allow_in_graph

0 commit comments

Comments
 (0)