Commits (31)
89f39be
Refactor output norm to use AdaLayerNorm in Wan transformers
tolgacangoz Jul 2, 2025
e4b30b8
fix: remove scale_shift_table from _keep_in_fp32_modules in Wan and W…
tolgacangoz Jul 2, 2025
92f8237
Fixes transformed head modulation layer mapping
tolgacangoz Jul 2, 2025
df07b88
Fix: Revert removing `scale_shift_table` from `_keep_in_fp32_modules`…
tolgacangoz Jul 2, 2025
e555903
Refactors transformer output blocks to use AdaLayerNorm
tolgacangoz Jul 2, 2025
921396a
Fix `head.modulation` mapping in conversion script
tolgacangoz Jul 3, 2025
ff95d5d
Fix handling of missing bias keys in conversion script
tolgacangoz Jul 3, 2025
3fd6f4e
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Jul 3, 2025
42c1451
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Jul 10, 2025
3178c4e
add backwardability
tolgacangoz Jul 10, 2025
65639d5
style
tolgacangoz Jul 10, 2025
6e6dfa8
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Jul 18, 2025
5f5c9dd
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Aug 16, 2025
51ecf1a
Adds option to disable SiLU in AdaLayerNorm
tolgacangoz Aug 16, 2025
3022b10
Refactors output normalization in `SkyReelsV2Transformer3DModel`
tolgacangoz Aug 16, 2025
35d8b3a
revert
tolgacangoz Aug 17, 2025
e0fe837
Implement `SkyReelsV2AdaLayerNorm` for timestep embedding modulation …
tolgacangoz Aug 17, 2025
1f649b4
style
tolgacangoz Aug 17, 2025
7b2b0f4
Enhances backward compatibility by converting deprecated `scale_shift…
tolgacangoz Aug 17, 2025
374530f
Adds missing keys to ignore on load and adjusts identity matrix creat…
tolgacangoz Aug 17, 2025
9c7ac6e
Fix device and dtype handling for identity matrix and scale_shift_tab…
tolgacangoz Aug 17, 2025
f8716e3
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Aug 17, 2025
df9efae
Update normalization.py
tolgacangoz Aug 17, 2025
b27c0bf
Delete src/diffusers/models/transformers/latte_transformer_3d.py
tolgacangoz Aug 18, 2025
05a8ce1
Revert currently unrelated ones
tolgacangoz Aug 18, 2025
7e9b9c0
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Aug 18, 2025
73c7189
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Aug 19, 2025
6bbf350
Remove `_load_from_state_dict`
tolgacangoz Aug 19, 2025
c0816c6
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Aug 20, 2025
d9b32df
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Aug 26, 2025
cb59d88
Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm
tolgacangoz Sep 3, 2025
98 changes: 66 additions & 32 deletions src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -40,6 +40,62 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class SkyReelsV2AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.

Parameters:
embedding_dim (`int`): The size of each embedding vector.
output_dim (`int`, *optional*): The output dimension of the modulation projection; defaults to `embedding_dim * 2`.
norm_elementwise_affine (`bool`, defaults to `False`): Whether the layer norm has learnable per-element affine parameters.
norm_eps (`float`, defaults to `1e-5`): The epsilon used by the layer norm.
"""

def __init__(
self,
embedding_dim: int,
output_dim: Optional[int] = None,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
):
super().__init__()

output_dim = output_dim or embedding_dim * 2

self.linear = nn.Linear(embedding_dim, output_dim)
# Initialize both halves of the weight as identity so the projection passes `temb` through unchanged;
# legacy `scale_shift_table` values can then be folded into the bias for backward compatibility.
self.linear.weight.data[:embedding_dim, :] = torch.eye(embedding_dim)
self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim)
self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)

def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
if temb.ndim == 2:
# If temb is 2D, it carries a single time-embedding vector per batch sample.
# For models:
# - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
# - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
# - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
# 2D temb: (batch, embedding_dim)
temb = self.linear(temb.unsqueeze(1)) # (batch, 1, embedding_dim * 2)
shift, scale = temb.chunk(2, dim=2)
elif temb.ndim == 3:
# If temb is 3D, it carries per-frame time embeddings (expanded over the spatial patches) for each
# batch sample, i.e. the Diffusion Forcing variant.
# For models:
# - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
# - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
# - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
# 3D temb: (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim)
temb = self.linear(
temb
) # (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim * 2)
shift, scale = temb.chunk(2, dim=2)

x = self.norm(x) * (1 + scale) + shift
return x
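
# --- Illustrative sketch (editorial aside, not part of this file) ------------
# Assuming the identity-initialized `linear` above, folding a legacy
# `scale_shift_table` (shape (1, 2, dim)) into `linear.bias` should reproduce
# the previous `norm(x) * (1 + scale) + shift` modulation for 2D `temb`.
# `old_table`, `dim`, and the tensors below are hypothetical stand-ins.
#
#   dim = 16
#   norm_out = SkyReelsV2AdaLayerNorm(embedding_dim=dim)
#   old_table = torch.randn(1, 2, dim) / dim**0.5
#   with torch.no_grad():
#       norm_out.linear.bias.copy_(torch.cat([old_table[0, 0], old_table[0, 1]]))
#   x, temb = torch.randn(2, 8, dim), torch.randn(2, dim)
#   shift, scale = (old_table + temb.unsqueeze(1)).chunk(2, dim=1)  # legacy path
#   assert torch.allclose(norm_out(x, temb), norm_out.norm(x) * (1 + scale) + shift, atol=1e-6)
# ------------------------------------------------------------------------------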


def _get_qkv_projections(
attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
):
@@ -559,9 +615,10 @@ class SkyReelsV2Transformer3DModel(

_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["SkyReelsV2TransformerBlock"]
_no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"]
_keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"]
_repeated_blocks = ["SkyReelsV2TransformerBlock"]

@register_to_config
@@ -617,9 +674,13 @@ def __init__(
)

# 4. Output norm & projection
- self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.norm_out = SkyReelsV2AdaLayerNorm(
+     embedding_dim=inner_dim,
+     output_dim=2 * inner_dim,
+     norm_elementwise_affine=False,
+     norm_eps=eps,
+ )
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
- self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

if inject_sample_info:
self.fps_embedding = nn.Embedding(2, inner_dim)
@@ -732,34 +793,7 @@ def forward(
causal_mask,
)

- if temb.dim() == 2:
- # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
- # For models:
- # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
- # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
- # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
- # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
- # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
- elif temb.dim() == 3:
- # If temb is 3D, we assume it has 2-D time embedding values for each batch.
- # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
- # For models:
- # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
- # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
- # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
- shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
- shift, scale = shift.squeeze(1), scale.squeeze(1)
-
- # Move the shift and scale tensors to the same device as hidden_states.
- # When using multi-GPU inference via accelerate these will be on the
- # first device rather than the last device, which hidden_states ends up
- # on.
- shift = shift.to(hidden_states.device)
- scale = scale.to(hidden_states.device)
-
- hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
-
+ hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
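
As a quick reference, a minimal shape sketch of the two `temb` layouts the new `norm_out` handles (the import path is assumed from this PR; `dim`, `seq`, and the tensors are illustrative, not from the source):

import torch
from diffusers.models.transformers.transformer_skyreels_v2 import SkyReelsV2AdaLayerNorm  # path assumed from this PR

dim, seq = 32, 120                              # stand-ins for inner_dim and the flattened latent sequence length
norm_out = SkyReelsV2AdaLayerNorm(embedding_dim=dim)
hidden_states = torch.randn(2, seq, dim)        # (batch, seq, inner_dim)

temb_2d = torch.randn(2, dim)                   # T2V / I2V checkpoints: one time embedding per sample
temb_3d = torch.randn(2, seq, dim)              # DF checkpoints: one time embedding per latent token

print(norm_out(hidden_states, temb_2d).shape)   # torch.Size([2, 120, 32])
print(norm_out(hidden_states, temb_3d).shape)   # torch.Size([2, 120, 32])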