Skip to content

Commit 7e237ad

Browse files
committed
style
1 parent 42113fc commit 7e237ad

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/diffusers/models/transformers/transformer_skyreels_v2.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ...configuration_utils import ConfigMixin, register_to_config
2323
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
2424
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25-
from ..attention import FeedForward
25+
from ...utils.torch_utils import maybe_allow_in_graph
2626
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
2727
from ..attention_dispatch import dispatch_attention_fn
2828
from ..cache_utils import CacheMixin
@@ -39,8 +39,11 @@
3939

4040
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4141

42+
4243
# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
43-
def _get_qkv_projections(attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
44+
def _get_qkv_projections(
45+
attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
46+
):
4447
# encoder_hidden_states is only passed for cross-attention
4548
if encoder_hidden_states is None:
4649
encoder_hidden_states = hidden_states
@@ -455,7 +458,7 @@ def forward(
455458
# For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
456459
e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
457460
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
458-
461+
459462
# 1. Self-attention
460463
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
461464
attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
@@ -476,7 +479,9 @@ def forward(
476479
return hidden_states
477480

478481

479-
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
482+
class SkyReelsV2Transformer3DModel(
483+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
484+
):
480485
r"""
481486
A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
482487

0 commit comments

Comments
 (0)