@@ -62,7 +62,6 @@ def _get_qkv_projections(
6262 return query , key , value
6363
6464
65- # Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
6665def _get_added_kv_projections (attn : "SkyReelsV2Attention" , encoder_hidden_states_img : torch .Tensor ):
6766 if attn .fused_projections :
6867 key_img , value_img = attn .to_added_kv (encoder_hidden_states_img ).chunk (2 , dim = - 1 )
@@ -211,7 +210,6 @@ def __init__(
211210
212211 self .set_processor (processor )
213212
214- # Copied from diffusers.models.transformers.transformer_wan.WanAttention.fuse_projections
215213 def fuse_projections (self ):
216214 if getattr (self , "fused_projections" , False ):
217215 return
@@ -248,7 +246,6 @@ def fuse_projections(self):
248246 self .fused_projections = True
249247
250248 @torch .no_grad ()
251- # Copied from diffusers.models.transformers.transformer_wan.WanAttention.unfuse_projections
252249 def unfuse_projections (self ):
253250 if not getattr (self , "fused_projections" , False ):
254251 return
@@ -262,7 +259,6 @@ def unfuse_projections(self):
262259
263260 self .fused_projections = False
264261
265- # Copied from diffusers.models.transformers.transformer_wan.WanAttention.forward
266262 def forward (
267263 self ,
268264 hidden_states : torch .Tensor ,
@@ -274,7 +270,6 @@ def forward(
274270 return self .processor (self , hidden_states , encoder_hidden_states , attention_mask , rotary_emb , ** kwargs )
275271
276272
277- # Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with Wan -> SkyReelsV2
278273class SkyReelsV2ImageEmbedding (torch .nn .Module ):
279274 def __init__ (self , in_features : int , out_features : int , pos_embed_seq_len = None ):
280275 super ().__init__ ()
@@ -363,7 +358,6 @@ def forward(
363358 return temb , timestep_proj , encoder_hidden_states , encoder_hidden_states_image
364359
365360
366- # Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed with Wan -> SkyReelsV2
367361class SkyReelsV2RotaryPosEmbed (nn .Module ):
368362 def __init__ (
369363 self ,
0 commit comments