Commit 06d3a62

refactor: remove copied comments from transformer_wan in SkyReelsV2 classes
1 parent e2f328b commit 06d3a62
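
For context: in diffusers, a "# Copied from ..." comment marks a definition that the repository's copy-consistency tooling (make fix-copies) keeps in sync with the named source definition, optionally applying a rename. The sketch below illustrates the convention using one of the markers this commit deletes; the abbreviated class body is illustrative, not the actual SkyReelsV2 code.

import torch

# A marker comment directly above a definition ties it to a source
# implementation; `make fix-copies` rewrites the body to match that source
# whenever the two drift, applying the stated rename ("with Wan -> SkyReelsV2").

# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with Wan -> SkyReelsV2
class SkyReelsV2ImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()  # body kept identical to WanImageEmbedding by the tooling

A copy that is meant to diverge must drop the marker first, since the next make fix-copies run would otherwise overwrite the divergence with the Wan version; removing the six markers below detaches these definitions from that sync.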

File tree

1 file changed: +0 -6 lines changed


src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 0 additions & 6 deletions
@@ -62,7 +62,6 @@ def _get_qkv_projections(
     return query, key, value


-# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
 def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor):
     if attn.fused_projections:
         key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
@@ -211,7 +210,6 @@ def __init__(

         self.set_processor(processor)

-    # Copied from diffusers.models.transformers.transformer_wan.WanAttention.fuse_projections
     def fuse_projections(self):
         if getattr(self, "fused_projections", False):
             return
@@ -248,7 +246,6 @@ def fuse_projections(self):
         self.fused_projections = True

     @torch.no_grad()
-    # Copied from diffusers.models.transformers.transformer_wan.WanAttention.unfuse_projections
     def unfuse_projections(self):
         if not getattr(self, "fused_projections", False):
             return
@@ -262,7 +259,6 @@ def unfuse_projections(self):

         self.fused_projections = False

-    # Copied from diffusers.models.transformers.transformer_wan.WanAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -274,7 +270,6 @@ def forward(
         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)


-# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with Wan -> SkyReelsV2
 class SkyReelsV2ImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
@@ -363,7 +358,6 @@ def forward(
         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image


-# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed with Wan -> SkyReelsV2
 class SkyReelsV2RotaryPosEmbed(nn.Module):
     def __init__(
         self,
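
A quick way to confirm the result is to scan the edited file for leftover markers. A minimal sketch, assuming a diffusers checkout with the working directory at the repository root:

from pathlib import Path

# Post-edit check (illustrative): list any remaining
# "Copied from ... transformer_wan" markers in the file this commit touches.
src = Path("src/diffusers/models/transformers/transformer_skyreels_v2.py")
stale = [
    (i, line.strip())
    for i, line in enumerate(src.read_text().splitlines(), start=1)
    if "Copied from" in line and "transformer_wan" in line
]
print(f"{len(stale)} marker(s) remaining")  # expected: 0 after this commit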
