Changes from 23 commits
Commits (31)
89f39be  Refactor output norm to use AdaLayerNorm in Wan transformers (tolgacangoz, Jul 2, 2025)
e4b30b8  fix: remove scale_shift_table from _keep_in_fp32_modules in Wan and W… (tolgacangoz, Jul 2, 2025)
92f8237  Fixes transformed head modulation layer mapping (tolgacangoz, Jul 2, 2025)
df07b88  Fix: Revert removing `scale_shift_table` from `_keep_in_fp32_modules`… (tolgacangoz, Jul 2, 2025)
e555903  Refactors transformer output blocks to use AdaLayerNorm (tolgacangoz, Jul 2, 2025)
921396a  Fix `head.modulation` mapping in conversion script (tolgacangoz, Jul 3, 2025)
ff95d5d  Fix handling of missing bias keys in conversion script (tolgacangoz, Jul 3, 2025)
3fd6f4e  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Jul 3, 2025)
42c1451  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Jul 10, 2025)
3178c4e  add backwardability (tolgacangoz, Jul 10, 2025)
65639d5  style (tolgacangoz, Jul 10, 2025)
6e6dfa8  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Jul 18, 2025)
5f5c9dd  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Aug 16, 2025)
51ecf1a  Adds option to disable SiLU in AdaLayerNorm (tolgacangoz, Aug 16, 2025)
3022b10  Refactors output normalization in `SkyReelsV2Transformer3DModel` (tolgacangoz, Aug 16, 2025)
35d8b3a  revert (tolgacangoz, Aug 17, 2025)
e0fe837  Implement `SkyReelsV2AdaLayerNorm` for timestep embedding modulation … (tolgacangoz, Aug 17, 2025)
1f649b4  style (tolgacangoz, Aug 17, 2025)
7b2b0f4  Enhances backward compatibility by converting deprecated `scale_shift… (tolgacangoz, Aug 17, 2025)
374530f  Adds missing keys to ignore on load and adjusts identity matrix creat… (tolgacangoz, Aug 17, 2025)
9c7ac6e  Fix device and dtype handling for identity matrix and scale_shift_tab… (tolgacangoz, Aug 17, 2025)
f8716e3  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Aug 17, 2025)
df9efae  Update normalization.py (tolgacangoz, Aug 17, 2025)
b27c0bf  Delete src/diffusers/models/transformers/latte_transformer_3d.py (tolgacangoz, Aug 18, 2025)
05a8ce1  Revert currently unrelated ones (tolgacangoz, Aug 18, 2025)
7e9b9c0  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Aug 18, 2025)
73c7189  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Aug 19, 2025)
6bbf350  Remove `_load_from_state_dict` (tolgacangoz, Aug 19, 2025)
c0816c6  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Aug 20, 2025)
d9b32df  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Aug 26, 2025)
cb59d88  Merge branch 'main' into transfer-shift_scale_norm-to-AdaLayerNorm (tolgacangoz, Sep 3, 2025)
28 changes: 21 additions & 7 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -23,11 +23,12 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle
 
 
 class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["norm_out"]
 
     """
     A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
@@ -149,8 +150,13 @@ def __init__(
 
         # 4. Define output layers
         self.out_channels = in_channels if out_channels is None else out_channels
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=1e-6,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
 
         # 5. Latte other blocks.
@@ -165,6 +171,17 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        if "scale_shift_table" in state_dict:
+            scale_shift_table = state_dict.pop("scale_shift_table")
+            state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]
+            state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0]
+        return super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -305,10 +322,7 @@ def forward(
         embedded_timestep = embedded_timestep.repeat_interleave(
             num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
         ).view(-1, embedded_timestep.shape[-1])
-        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-        # Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
         hidden_states = self.proj_out(hidden_states)
 
         # unpatchify
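The pattern being refactored above recurs in all four files: the output block used to keep a learned `scale_shift_table`, add it to the timestep embedding, chunk the result into `shift` and `scale`, and modulate the LayerNorm output by hand; the PR folds that sequence into a single `norm_out(hidden_states, temb=...)` call. A minimal standalone sketch of the old path (illustrative shapes and tensor names, not taken from the PR):

```python
import torch
import torch.nn as nn

# Sketch of the pre-refactor output path with toy shapes.
batch, seq_len, inner_dim = 2, 16, 8
hidden_states = torch.randn(batch, seq_len, inner_dim)
embedded_timestep = torch.randn(batch, inner_dim)

norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)

# Old path: derive shift/scale from the table plus the timestep embedding, then modulate.
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
out = norm_out(hidden_states) * (1 + scale) + shift  # (2, 16, 8)

# The refactor hides exactly this "add table -> chunk -> modulate" sequence behind one call:
# hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
```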
31 changes: 21 additions & 10 deletions src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -23,7 +23,7 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
 
     @register_to_config
@@ -171,8 +171,13 @@ def __init__(
         )
 
         # 3. Output blocks.
-        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
-        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=self.inner_dim,
+            output_dim=2 * self.inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=1e-6,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
 
         self.adaln_single = AdaLayerNormSingle(
@@ -184,6 +189,17 @@ def __init__(
             in_features=self.config.caption_channels, hidden_size=self.inner_dim
         )
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        if "scale_shift_table" in state_dict:
+            scale_shift_table = state_dict.pop("scale_shift_table")
+            state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]
+            state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0]
+        return super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -406,12 +422,7 @@ def forward(
                )
 
         # 3. Output
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-        # Modulation
-        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.squeeze(1)
 
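The new `_load_from_state_dict` overrides exist for backward compatibility: they intercept the deprecated `scale_shift_table` key from old checkpoints and rewrite it into `norm_out.linear.*` entries before PyTorch's normal loading runs. The toy module below (hypothetical names, not the diffusers classes) sketches the underlying mechanism of renaming a legacy key inside such an override and then delegating to the parent implementation:

```python
import torch
import torch.nn as nn

# Toy sketch: rewrite deprecated checkpoint keys in _load_from_state_dict,
# then hand off to nn.Module's regular loading logic.
class TinyBlock(nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Rename the legacy "old_proj.*" keys to the current "proj.*" keys in place.
        for suffix in ("weight", "bias"):
            legacy = prefix + "old_proj." + suffix
            if legacy in state_dict:
                state_dict[prefix + "proj." + suffix] = state_dict.pop(legacy)
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

legacy_ckpt = {"old_proj.weight": torch.randn(4, 4), "old_proj.bias": torch.zeros(4)}
block = TinyBlock()
block.load_state_dict(legacy_ckpt)  # loads cleanly despite the renamed parameters
```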
29 changes: 21 additions & 8 deletions src/diffusers/models/transformers/transformer_allegro.py
@@ -28,7 +28,7 @@
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle
 
 
 logger = logging.get_logger(__name__)
@@ -175,6 +175,7 @@ def forward(
 
 class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["norm_out"]
 
     """
     A 3D Transformer model for video-like data.
@@ -292,8 +293,13 @@ def __init__(
         )
 
         # 3. Output projection & norm
-        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
-        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=self.inner_dim,
+            output_dim=2 * self.inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=1e-6,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
 
         # 4. Timestep embeddings
@@ -304,6 +310,17 @@ def __init__(
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        if "scale_shift_table" in state_dict:
+            scale_shift_table = state_dict.pop("scale_shift_table")
+            state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1]
+            state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0]
+        return super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -393,11 +410,7 @@ def forward(
                )
 
         # 4. Output normalization & projection
-        shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
-
-        # Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep)
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.squeeze(1)
 
41 changes: 33 additions & 8 deletions src/diffusers/models/transformers/transformer_ltx.py
@@ -30,7 +30,7 @@
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -408,6 +408,7 @@ class LTXVideoTransformer3DModel(
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
+    _no_split_modules = ["norm_out"]
     _repeated_blocks = ["LTXVideoTransformerBlock"]
 
     @register_to_config
@@ -436,7 +437,6 @@ def __init__(
 
         self.proj_in = nn.Linear(in_channels, inner_dim)
 
-        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
         self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
 
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
@@ -469,11 +469,40 @@ def __init__(
             ]
         )
 
-        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=1e-6,
+            chunk_dim=1,
+        )
         self.proj_out = nn.Linear(inner_dim, out_channels)
 
         self.gradient_checkpointing = False
 
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        key = "scale_shift_table"
+        if prefix + key in state_dict:
+            scale_shift_table = state_dict.pop(prefix + key)
+            inner_dim = scale_shift_table.shape[-1]
+
+            weight = torch.eye(inner_dim).repeat(2, 1)
+            bias = scale_shift_table.reshape(2, inner_dim).flatten()
+
+            state_dict[prefix + "norm_out.linear.weight"] = weight
+            state_dict[prefix + "norm_out.linear.bias"] = bias
+
+            if prefix + "norm_out.weight" in state_dict:
+                state_dict.pop(prefix + "norm_out.weight")
+            if prefix + "norm_out.bias" in state_dict:
+                state_dict.pop(prefix + "norm_out.bias")
+
+        return super(LTXVideoTransformer3DModel, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -544,11 +573,7 @@ def forward(
                     encoder_attention_mask=encoder_attention_mask,
                )
 
-        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
-        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
-
-        hidden_states = self.norm_out(hidden_states)
-        hidden_states = hidden_states * (1 + scale) + shift
+        hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1))
        output = self.proj_out(hidden_states)
 
         if USE_PEFT_BACKEND:
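Unlike the simple key renames in the other files, the LTX override has to turn the old additive table into the parameters of a linear projection, since `AdaLayerNorm` derives shift and scale from a projection of the timestep embedding. A quick numerical sketch (toy shapes, and leaving aside any activation the projection may apply to `temb`) shows why an identity matrix stacked twice plus the flattened table reproduces the old `scale_shift_table[None] + temb[:, None]` term:

```python
import torch

torch.manual_seed(0)
inner_dim, batch = 8, 2
scale_shift_table = torch.randn(2, inner_dim) / inner_dim**0.5
temb = torch.randn(batch, inner_dim)

# Old formulation: broadcast-add the table to the timestep embedding, then chunk.
old = scale_shift_table[None] + temb[:, None]             # (batch, 2, inner_dim)

# New formulation: one linear layer whose weight is the identity stacked twice
# and whose bias is the flattened table, as in the LTX _load_from_state_dict above.
weight = torch.eye(inner_dim).repeat(2, 1)                 # (2 * inner_dim, inner_dim)
bias = scale_shift_table.reshape(2, inner_dim).flatten()   # (2 * inner_dim,)
new = torch.nn.functional.linear(temb, weight, bias)       # (batch, 2 * inner_dim)

assert torch.allclose(old.reshape(batch, -1), new)
```

Whether old checkpoints then produce bit-for-bit identical outputs also depends on the chunk order and on any SiLU applied to `temb` inside `AdaLayerNorm`, which is presumably what the "Adds option to disable SiLU in AdaLayerNorm" commit in this PR addresses.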