Skip to content

Commit 1c55871

Browse files
committed
Merge branch 'main' into model-test-refactor
2 parents 1f026ad + 0c75892 commit 1c55871

File tree

4 files changed

+90
-89
lines changed

4 files changed

+90
-89
lines changed

src/diffusers/models/transformers/transformer_sana_video.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@ def apply_rotary_emb(
172172
return hidden_states
173173

174174

175-
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
176175
class WanRotaryPosEmbed(nn.Module):
177176
def __init__(
178177
self,

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ def __init__(
389389
t_dim = attention_head_dim - h_dim - w_dim
390390
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
391391

392+
self.t_dim = t_dim
393+
self.h_dim = h_dim
394+
self.w_dim = w_dim
395+
392396
freqs_cos = []
393397
freqs_sin = []
394398

@@ -412,11 +416,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
412416
p_t, p_h, p_w = self.patch_size
413417
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
414418

415-
split_sizes = [
416-
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
417-
self.attention_head_dim // 3,
418-
self.attention_head_dim // 3,
419-
]
419+
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
420420

421421
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
422422
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,11 @@ def __init__(
362362

363363
h_dim = w_dim = 2 * (attention_head_dim // 6)
364364
t_dim = attention_head_dim - h_dim - w_dim
365+
366+
self.t_dim = t_dim
367+
self.h_dim = h_dim
368+
self.w_dim = w_dim
369+
365370
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
366371

367372
freqs_cos = []
@@ -387,11 +392,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
387392
p_t, p_h, p_w = self.patch_size
388393
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
389394

390-
split_sizes = [
391-
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
392-
self.attention_head_dim // 3,
393-
self.attention_head_dim // 3,
394-
]
395+
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
395396

396397
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
397398
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

0 commit comments

Comments
 (0)