@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -39,8 +39,11 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 # Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
-def _get_qkv_projections(attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+def _get_qkv_projections(
+    attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+):
     # encoder_hidden_states is only passed for cross-attention
     if encoder_hidden_states is None:
         encoder_hidden_states = hidden_states
@@ -455,7 +458,7 @@ def forward(
         # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
         e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
-
+
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
         attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
@@ -476,7 +479,9 @@ def forward(
         return hidden_states
 
 
-class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
+class SkyReelsV2Transformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
 
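For context beyond the diff itself: the class whose declaration is reformatted above inherits ModelMixin, so it is loaded with the standard diffusers from_pretrained API. A minimal sketch follows; the checkpoint repo id and subfolder are illustrative placeholders, not taken from this diff.

    import torch

    from diffusers import SkyReelsV2Transformer3DModel

    # Minimal sketch: instantiate the transformer from a pretrained checkpoint
    # via ModelMixin.from_pretrained. The repo id and subfolder below are
    # hypothetical placeholders for a SkyReels-V2 diffusers-format checkpoint.
    transformer = SkyReelsV2Transformer3DModel.from_pretrained(
        "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers",  # hypothetical repo id
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )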