From 89f39bec13b8e027e11d8aecbe77b6f13f1ba771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 18:55:17 +0300 Subject: [PATCH 01/21] Refactor output norm to use AdaLayerNorm in Wan transformers Replace the final `FP32LayerNorm` and manual shift/scale application with a single `AdaLayerNorm` module in both the `WanTransformer3DModel` and `WanVACETransformer3DModel`. This change simplifies the forward pass by encapsulating the adaptive normalization logic within the `AdaLayerNorm` layer, removing the need for a separate `scale_shift_table`. The `_no_split_modules` list is also updated to include `norm_out` for compatibility with model parallelism. --- .../models/transformers/transformer_wan.py | 24 ++++++++----------- .../transformers/transformer_wan_vace.py | 24 ++++++++----------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 5fb71b69f7ac..9702d47a9cf5 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -28,7 +28,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -370,7 +370,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock"] + _no_split_modules = ["WanTransformerBlock", "norm_out"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] @@ -426,9 +426,14 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=eps, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) - self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False @@ -486,16 +491,7 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. Output norm, projection & unpatchify - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. 
- shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 1a6f2af59a87..0813051cf5fc 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -26,7 +26,7 @@ from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock @@ -179,7 +179,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"] + _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @@ -259,9 +259,14 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=eps, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) - self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False @@ -365,16 +370,7 @@ def forward( hidden_states = hidden_states + control_hint * scale # 6. Output norm, projection & unpatchify - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. 
- shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( From e4b30b88beb170fdcdcd36d6a5c1fe76ba1cb2fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 19:51:34 +0300 Subject: [PATCH 02/21] fix: remove scale_shift_table from _keep_in_fp32_modules in Wan and WanVACE transformers --- src/diffusers/models/transformers/transformer_wan.py | 2 +- src/diffusers/models/transformers/transformer_wan_vace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 9702d47a9cf5..d5e2fe2aea31 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -371,7 +371,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 0813051cf5fc..2a6a64032f5a 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -180,7 +180,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config From 92f8237638bc18a6539cbcb6c72461ef3dac1728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 20:21:05 +0300 Subject: [PATCH 03/21] Fixes transformed head modulation layer mapping Updates the key mapping for the `head.modulation` layer to `norm_out.linear` in the model conversion script. This correction ensures that weights are loaded correctly for both standard and VACE transformer models. 
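For reference, the `AdaLayerNorm` configuration used throughout this series (`output_dim=2 * inner_dim`, `chunk_dim=1`) boils down to the computation sketched below, which is why the standalone `scale_shift_table` (the original `head.modulation` tensor) can be absorbed into `norm_out.linear`. The helper is an illustrative paraphrase rather than the library code, and the argument names and shapes are assumptions:

```python
import torch
import torch.nn.functional as F

def ada_layer_norm_out(x, temb, linear, norm):
    # x: (batch, seq_len, dim); temb: (batch, dim)
    # linear: nn.Linear(dim, 2 * dim); norm: nn.LayerNorm(dim, elementwise_affine=False)
    emb = linear(F.silu(temb))              # (batch, 2 * dim)
    shift, scale = emb.chunk(2, dim=1)      # "shift, scale" order when chunk_dim == 1
    return norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]
```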
--- scripts/convert_wan_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 6d25cde071b1..24cb798cc198 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -25,7 +25,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", + "head.modulation": "norm_out.linear", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", @@ -67,7 +67,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", + "head.modulation": "norm_out.linear", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", From df07b88b0fac437dd1320ea2d67988017940dad0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 20:32:49 +0300 Subject: [PATCH 04/21] Fix: Revert removing `scale_shift_table` from `_keep_in_fp32_modules` in Wan and WanVACE transformers --- src/diffusers/models/transformers/transformer_wan.py | 2 +- src/diffusers/models/transformers/transformer_wan_vace.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index d5e2fe2aea31..9702d47a9cf5 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -371,7 +371,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 2a6a64032f5a..0813051cf5fc 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -180,7 +180,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] - _keep_in_fp32_modules = ["time_embedder", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config From e555903067c43d13fbafa84d0f39d918a408e7d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Jul 2025 21:22:17 +0300 Subject: [PATCH 05/21] Refactors transformer output blocks to use AdaLayerNorm Replaces the manual implementation of adaptive layer normalization, which used a separate `scale_shift_table` and `nn.LayerNorm`, with the 
unified `AdaLayerNorm` module. This change simplifies the forward pass logic in several transformer models by encapsulating the normalization and modulation steps into a single component. It also adds `norm_out` to `_no_split_modules` for model parallelism compatibility. --- .../transformers/latte_transformer_3d.py | 17 +++++++++------- .../transformers/pixart_transformer_2d.py | 20 +++++++++---------- .../transformers/transformer_allegro.py | 18 +++++++++-------- .../models/transformers/transformer_ltx.py | 18 +++++++++-------- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 990c90512e39..486969047d6a 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -23,11 +23,12 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True + _no_split_modules = ["norm_out"] """ A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code: @@ -149,8 +150,13 @@ def __init__( # 4. Define output layers self.out_channels = in_channels if out_channels is None else out_channels - self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) # 5. Latte other blocks. 
@@ -305,10 +311,7 @@ def forward( embedded_timestep = embedded_timestep.repeat_interleave( num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame ).view(-1, embedded_timestep.shape[-1]) - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) # unpatchify diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 40a14bfd9b27..f05da448367d 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -23,7 +23,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + _no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"] _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config @@ -171,8 +171,13 @@ def __init__( ) # 3. Output blocks. - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=self.inner_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) self.adaln_single = AdaLayerNormSingle( @@ -406,12 +411,7 @@ def forward( ) # 3. Output - shift, scale = ( - self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) - ).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - # Modulation - hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 5fa59a71d977..a20c64801de5 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -28,7 +28,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle +from ..normalization import AdaLayerNorm, AdaLayerNormSingle logger = logging.get_logger(__name__) @@ -175,6 +175,7 @@ def forward( class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True + _no_split_modules = ["norm_out"] """ A 3D Transformer model for video-like data. @@ -292,8 +293,13 @@ def __init__( ) # 3. 
Output projection & norm - self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) - self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.norm_out = AdaLayerNorm( + embedding_dim=self.inner_dim, + output_dim=2 * self.inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. Timestep embeddings @@ -393,11 +399,7 @@ def forward( ) # 4. Output normalization & projection - shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) - hidden_states = self.norm_out(hidden_states) - - # Modulation - hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2d06124282d1..91e40c8f3471 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -30,7 +30,7 @@ from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormSingle, RMSNorm +from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] + _no_split_modules = ["norm_out"] _repeated_blocks = ["LTXVideoTransformerBlock"] @register_to_config @@ -356,7 +357,6 @@ def __init__( self.proj_in = nn.Linear(in_channels, inner_dim) - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) @@ -389,7 +389,13 @@ def __init__( ] ) - self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=1e-6, + chunk_dim=1, + ) self.proj_out = nn.Linear(inner_dim, out_channels) self.gradient_checkpointing = False @@ -464,11 +470,7 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1)) output = self.proj_out(hidden_states) if USE_PEFT_BACKEND: From 921396a147868486a5662e10e3285bdd5ee35550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 3 Jul 2025 18:37:20 +0300 Subject: [PATCH 06/21] Fix `head.modulation` mapping in conversion script Corrects the target key for `head.modulation` to `norm_out.linear.weight`. This ensures the weights are correctly mapped to the weight parameter of the output normalization layer during model conversion for both transformer types. 
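Note on shapes: `head.modulation` is a `(1, 2, inner_dim)` table, while `norm_out.linear` carries a `(2 * inner_dim, inner_dim)` weight and a `(2 * inner_dim,)` bias, so a bare key rename cannot populate both parameters. Later commits in this series instead fold the table into the linear layer (identity-stacked weight, flattened table as bias). The snippet below is a self-contained, illustrative check of that equivalence; the tensor sizes are made up, and it deliberately ignores the SiLU that `AdaLayerNorm` applies to `temb` before the projection:

```python
import torch

batch, dim = 2, 8                                   # toy sizes; inner_dim comes from the model config
table = torch.randn(1, 2, dim) / dim**0.5           # stands in for `scale_shift_table` / `head.modulation`
temb = torch.randn(batch, dim)

# Old output path: shift/scale are the table rows plus the broadcast time embedding.
shift_old, scale_old = (table + temb.unsqueeze(1)).chunk(2, dim=1)

# Folded form: Linear(dim, 2 * dim) with identity-stacked weight and the flattened table as bias.
linear = torch.nn.Linear(dim, 2 * dim)
with torch.no_grad():
    linear.weight.copy_(torch.eye(dim).repeat(2, 1))
    linear.bias.copy_(table.reshape(2, dim).flatten())

shift_new, scale_new = linear(temb).chunk(2, dim=1)

assert torch.allclose(shift_old.squeeze(1), shift_new, atol=1e-6)
assert torch.allclose(scale_old.squeeze(1), scale_new, atol=1e-6)
```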
--- scripts/convert_wan_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 24cb798cc198..cf287bad0149 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -25,7 +25,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "norm_out.linear", + "head.modulation": "norm_out.linear.weight", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", @@ -67,7 +67,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "norm_out.linear", + "head.modulation": "norm_out.linear.weight", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", From ff95d5db6badb99242006733008d08eda1050ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 3 Jul 2025 19:51:01 +0300 Subject: [PATCH 07/21] Fix handling of missing bias keys in conversion script Adds a default zero-initialized bias tensor for the transformer's output normalization layer if it is missing from the original state dictionary. --- scripts/convert_wan_to_diffusers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index cf287bad0149..012982fe8f97 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -105,8 +105,12 @@ "after_proj": "proj_out", } -TRANSFORMER_SPECIAL_KEYS_REMAP = {} -VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) +} +VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = { + "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) +} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -308,6 +312,10 @@ def convert_transformer(model_type: str): continue handler_fn_inplace(key, original_state_dict) + for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items(): + if special_key not in original_state_dict: + handler_fn_inplace(special_key, original_state_dict) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer From 3178c4ebd286e1216c52173a04ef37b5efd77bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 10 Jul 2025 17:50:16 +0300 Subject: [PATCH 08/21] add backwardability --- scripts/convert_wan_to_diffusers.py | 16 ++++---------- .../transformers/latte_transformer_3d.py | 9 ++++++++ .../transformers/pixart_transformer_2d.py | 15 ++++++++++--- .../transformers/transformer_allegro.py | 9 ++++++++ .../models/transformers/transformer_ltx.py | 21 +++++++++++++++++++ .../models/transformers/transformer_wan.py | 21 +++++++++++++++++++ .../transformers/transformer_wan_vace.py | 21 +++++++++++++++++++ 7 files changed, 97 insertions(+), 15 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 012982fe8f97..6d25cde071b1 100644 --- a/scripts/convert_wan_to_diffusers.py +++ 
b/scripts/convert_wan_to_diffusers.py @@ -25,7 +25,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "norm_out.linear.weight", + "head.modulation": "scale_shift_table", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", @@ -67,7 +67,7 @@ "text_embedding.0": "condition_embedder.text_embedder.linear_1", "text_embedding.2": "condition_embedder.text_embedder.linear_2", "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "norm_out.linear.weight", + "head.modulation": "scale_shift_table", "head.head": "proj_out", "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", @@ -105,12 +105,8 @@ "after_proj": "proj_out", } -TRANSFORMER_SPECIAL_KEYS_REMAP = { - "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) -} -VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = { - "norm_out.linear.bias": lambda key, state_dict: state_dict.setdefault(key, torch.zeros(state_dict["norm_out.linear.weight"].shape[0])) -} +TRANSFORMER_SPECIAL_KEYS_REMAP = {} +VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {} def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: @@ -312,10 +308,6 @@ def convert_transformer(model_type: str): continue handler_fn_inplace(key, original_state_dict) - for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items(): - if special_key not in original_state_dict: - handler_fn_inplace(special_key, original_state_dict) - transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 486969047d6a..110ef4e7c086 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -171,6 +171,15 @@ def __init__( self.gradient_checkpointing = False + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + if "scale_shift_table" in state_dict: + scale_shift_table = state_dict.pop("scale_shift_table") + state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] + state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def forward( self, hidden_states: torch.Tensor, diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index f05da448367d..ea8b10f01024 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -185,9 +185,18 @@ def __init__( ) self.caption_projection = None if self.config.caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.config.caption_channels, hidden_size=self.inner_dim - ) + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.config.caption_channels, hidden_size=self.inner_dim + ) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + if "scale_shift_table" in state_dict: + scale_shift_table = 
state_dict.pop("scale_shift_table") + state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] + state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index a20c64801de5..85dc1f52d008 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -310,6 +310,15 @@ def __init__( self.gradient_checkpointing = False + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + if "scale_shift_table" in state_dict: + scale_shift_table = state_dict.pop("scale_shift_table") + state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] + state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def forward( self, hidden_states: torch.Tensor, diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 91e40c8f3471..c25539088982 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -400,6 +400,27 @@ def __init__( self.gradient_checkpointing = False + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + key = "scale_shift_table" + if prefix + key in state_dict: + scale_shift_table = state_dict.pop(prefix + key) + inner_dim = scale_shift_table.shape[-1] + + weight = torch.eye(inner_dim).repeat(2, 1) + bias = scale_shift_table.reshape(2, inner_dim).flatten() + + state_dict[prefix + "norm_out.linear.weight"] = weight + state_dict[prefix + "norm_out.linear.bias"] = bias + + if prefix + "norm_out.weight" in state_dict: + state_dict.pop(prefix + "norm_out.weight") + if prefix + "norm_out.bias" in state_dict: + state_dict.pop(prefix + "norm_out.bias") + + return super(LTXVideoTransformer3DModel, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def forward( self, hidden_states: torch.Tensor, diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 505336ae502b..e61a6a319b61 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -439,6 +439,27 @@ def __init__( self.gradient_checkpointing = False + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + key = "scale_shift_table" + if prefix + key in state_dict: + scale_shift_table = state_dict.pop(prefix + key) + inner_dim = scale_shift_table.shape[-1] + + weight = torch.eye(inner_dim).repeat(2, 1) + bias = scale_shift_table.reshape(2, inner_dim).flatten() + + state_dict[prefix + "norm_out.linear.weight"] = weight + state_dict[prefix + "norm_out.linear.bias"] = bias + + if prefix + "norm_out.weight" in state_dict: + state_dict.pop(prefix + "norm_out.weight") + if prefix + "norm_out.bias" in state_dict: + 
state_dict.pop(prefix + "norm_out.bias") + + return super(WanTransformer3DModel, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def forward( self, hidden_states: torch.Tensor, diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 0813051cf5fc..44a40fb4b1d2 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -270,6 +270,27 @@ def __init__( self.gradient_checkpointing = False + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + key = "scale_shift_table" + if prefix + key in state_dict: + scale_shift_table = state_dict.pop(prefix + key) + inner_dim = scale_shift_table.shape[-1] + + weight = torch.eye(inner_dim).repeat(2, 1) + bias = scale_shift_table.reshape(2, inner_dim).flatten() + + state_dict[prefix + "norm_out.linear.weight"] = weight + state_dict[prefix + "norm_out.linear.bias"] = bias + + if prefix + "norm_out.weight" in state_dict: + state_dict.pop(prefix + "norm_out.weight") + if prefix + "norm_out.bias" in state_dict: + state_dict.pop(prefix + "norm_out.bias") + + return super(WanVACETransformer3DModel, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def forward( self, hidden_states: torch.Tensor, From 65639d5101e55a7efa2b7552e9abcebaf084016d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 10 Jul 2025 17:50:55 +0300 Subject: [PATCH 09/21] style --- .../models/transformers/latte_transformer_3d.py | 4 +++- .../models/transformers/pixart_transformer_2d.py | 10 ++++++---- .../models/transformers/transformer_allegro.py | 4 +++- src/diffusers/models/transformers/transformer_ltx.py | 4 +++- src/diffusers/models/transformers/transformer_wan.py | 4 +++- .../models/transformers/transformer_wan_vace.py | 4 +++- 6 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 110ef4e7c086..03284bc2a624 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -171,7 +171,9 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): if "scale_shift_table" in state_dict: scale_shift_table = state_dict.pop("scale_shift_table") state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index ea8b10f01024..03fb1f8c30b8 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -185,11 +185,13 @@ def __init__( ) self.caption_projection = None if self.config.caption_channels is not None: - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.config.caption_channels, hidden_size=self.inner_dim - ) + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.config.caption_channels, 
hidden_size=self.inner_dim + ) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): if "scale_shift_table" in state_dict: scale_shift_table = state_dict.pop("scale_shift_table") state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 85dc1f52d008..fae075985935 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -310,7 +310,9 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): if "scale_shift_table" in state_dict: scale_shift_table = state_dict.pop("scale_shift_table") state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index c25539088982..44c17f9fa8ee 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -400,7 +400,9 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): key = "scale_shift_table" if prefix + key in state_dict: scale_shift_table = state_dict.pop(prefix + key) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index e61a6a319b61..aea524491557 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -439,7 +439,9 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): key = "scale_shift_table" if prefix + key in state_dict: scale_shift_table = state_dict.pop(prefix + key) diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index 44a40fb4b1d2..e25e75590829 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -270,7 +270,9 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): key = "scale_shift_table" if prefix + key in state_dict: scale_shift_table = state_dict.pop(prefix + key) From 51ecf1adb98a429c81d162982addf51aacdb270a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 16 Aug 2025 19:43:38 +0300 Subject: [PATCH 10/21] Adds option to disable SiLU 
in AdaLayerNorm --- src/diffusers/models/normalization.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index ae2a6298f5f7..22fd8c877443 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -36,6 +36,7 @@ class AdaLayerNorm(nn.Module): norm_elementwise_affine (`bool`, defaults to `False): norm_eps (`bool`, defaults to `False`): chunk_dim (`int`, defaults to `0`): + use_silu (`bool`, defaults to `True`): Whether to apply SiLU activation before the linear layer. """ def __init__( @@ -46,10 +47,12 @@ def __init__( norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, chunk_dim: int = 0, + use_silu: bool = True, ): super().__init__() self.chunk_dim = chunk_dim + self.use_silu = use_silu output_dim = output_dim or embedding_dim * 2 if num_embeddings is not None: @@ -57,7 +60,8 @@ def __init__( else: self.emb = None - self.silu = nn.SiLU() + if self.use_silu: + self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) @@ -67,7 +71,9 @@ def forward( if self.emb is not None: temb = self.emb(timestep) - temb = self.linear(self.silu(temb)) + if self.use_silu: + temb = self.silu(temb) + temb = self.linear(temb) if self.chunk_dim == 1: # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the From 3022b10164e3cdada863153a1aa3519d39da0a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 16 Aug 2025 19:46:00 +0300 Subject: [PATCH 11/21] Refactors output normalization in `SkyReelsV2Transformer3DModel` --- .../transformers/transformer_skyreels_v2.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 236fca690a90..919ea5d8c24b 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -33,7 +33,7 @@ ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin, get_parameter_dtype -from ..normalization import FP32LayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -386,7 +386,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["SkyReelsV2TransformerBlock"] + _no_split_modules = ["SkyReelsV2TransformerBlock", "norm_out"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @@ -443,9 +443,15 @@ def __init__( ) # 4. 
Output norm & projection - self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.norm_out = AdaLayerNorm( + embedding_dim=inner_dim, + output_dim=2 * inner_dim, + norm_elementwise_affine=False, + norm_eps=eps, + chunk_dim=1, + use_silu=False, + ) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) - self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) if inject_sample_info: self.fps_embedding = nn.Embedding(2, inner_dim) @@ -558,7 +564,6 @@ def forward( causal_mask, ) - if temb.dim() == 2: # If temb is 2D, we assume it has time 1-D time embedding values for each batch. # For models: # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers @@ -566,26 +571,15 @@ def forward( # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - elif temb.dim() == 3: + # If temb is 3D, we assume it has 2-D time embedding values for each batch. # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing. # For models: # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers - shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1) - shift, scale = shift.squeeze(1), scale.squeeze(1) - - # Move the shift and scale tensors to the same device as hidden_states. - # When using multi-GPU inference via accelerate these will be on the - # first device rather than the last device, which hidden_states ends up - # on. - shift = shift.to(hidden_states.device) - scale = scale.to(hidden_states.device) - - hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( From 35d8b3a3de4120eb8e111db6bdffc759284b22a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 12:12:52 +0300 Subject: [PATCH 12/21] revert --- src/diffusers/models/normalization.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 22fd8c877443..9cdac46d369a 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -36,7 +36,6 @@ class AdaLayerNorm(nn.Module): norm_elementwise_affine (`bool`, defaults to `False): norm_eps (`bool`, defaults to `False`): chunk_dim (`int`, defaults to `0`): - use_silu (`bool`, defaults to `True`): Whether to apply SiLU activation before the linear layer. 
""" def __init__( @@ -47,12 +46,10 @@ def __init__( norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, chunk_dim: int = 0, - use_silu: bool = True, ): super().__init__() self.chunk_dim = chunk_dim - self.use_silu = use_silu output_dim = output_dim or embedding_dim * 2 if num_embeddings is not None: @@ -60,8 +57,7 @@ def __init__( else: self.emb = None - if self.use_silu: - self.silu = nn.SiLU() + self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) @@ -71,8 +67,7 @@ def forward( if self.emb is not None: temb = self.emb(timestep) - if self.use_silu: - temb = self.silu(temb) + temb = self.silu(temb) temb = self.linear(temb) if self.chunk_dim == 1: From e0fe8377f4545306cd3087d4a971379307f54adc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 18:04:35 +0300 Subject: [PATCH 13/21] Implement `SkyReelsV2AdaLayerNorm` for timestep embedding modulation and normalization --- .../transformers/transformer_skyreels_v2.py | 79 ++++++++++++++----- 1 file changed, 59 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 919ea5d8c24b..d54542b84e68 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -33,12 +33,68 @@ ) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin, get_parameter_dtype -from ..normalization import AdaLayerNorm, FP32LayerNorm +from ..normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class SkyReelsV2AdaLayerNorm(nn.Module): + r""" + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + output_dim (`int`, *optional*): + norm_elementwise_affine (`bool`, defaults to `False): + norm_eps (`bool`, defaults to `False`): + """ + + def __init__( + self, + embedding_dim: int, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + ): + super().__init__() + + output_dim = output_dim or embedding_dim * 2 + + self.linear = nn.Linear(embedding_dim, output_dim) + self.linear.weight.data[:embedding_dim, :] = torch.eye(embedding_dim) + self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim) + self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine) + + def forward( + self, x: torch.Tensor, temb: torch.Tensor + ) -> torch.Tensor: + if temb.ndim == 2: + # If temb is 2D, we assume it has 1-D time embedding values for each batch. + # For models: + # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers + # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers + # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers + # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers + # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers + # 2D temb: (batch, embedding_dim) + temb = self.linear(temb.unsqueeze(1)) # (batch, 1, embedding_dim * 2) + shift, scale = temb.chunk(2, dim=2) + elif temb.ndim == 3: + # If temb is 3D, we assume it has 2-D time embedding values for each batch. + # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing. 
+ # For models: + # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers + # 3D temb: (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim) + temb = self.linear(temb) # (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim * 2) + shift, scale = temb.chunk(2, dim=2) + + x = self.norm(x) * (1 + scale) + shift + return x + + class SkyReelsV2AttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -386,7 +442,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["SkyReelsV2TransformerBlock", "norm_out"] + _no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @@ -443,13 +499,11 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = AdaLayerNorm( + self.norm_out = SkyReelsV2AdaLayerNorm( embedding_dim=inner_dim, output_dim=2 * inner_dim, norm_elementwise_affine=False, norm_eps=eps, - chunk_dim=1, - use_silu=False, ) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) @@ -564,21 +618,6 @@ def forward( causal_mask, ) - # If temb is 2D, we assume it has time 1-D time embedding values for each batch. - # For models: - # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers - # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers - # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers - # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers - # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers - - # If temb is 3D, we assume it has 2-D time embedding values for each batch. - # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing. - # For models: - # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers - # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers - # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers - hidden_states = self.norm_out(hidden_states, temb=temb) hidden_states = self.proj_out(hidden_states) From 1f649b41a33f7e131298c5de66e700a98cdc2747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 18:09:06 +0300 Subject: [PATCH 14/21] style --- .../models/transformers/transformer_skyreels_v2.py | 8 ++++---- src/diffusers/models/transformers/transformer_wan_vace.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d54542b84e68..d9130f4c329a 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -66,9 +66,7 @@ def __init__( self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim) self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine) - def forward( - self, x: torch.Tensor, temb: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: if temb.ndim == 2: # If temb is 2D, we assume it has 1-D time embedding values for each batch. 
# For models: @@ -88,7 +86,9 @@ def forward( # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers # 3D temb: (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim) - temb = self.linear(temb) # (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim * 2) + temb = self.linear( + temb + ) # (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim * 2) shift, scale = temb.chunk(2, dim=2) x = self.norm(x) * (1 + scale) + shift diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index b835251a7c7a..a76be1da99b3 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -25,7 +25,7 @@ from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm +from ..normalization import AdaLayerNorm, FP32LayerNorm from .transformer_wan import ( WanAttention, WanAttnProcessor, From 7b2b0f4e32f804d5a6f62f1585031a59e69c204c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 18:46:26 +0300 Subject: [PATCH 15/21] Enhances backward compatibility by converting deprecated `scale_shift_table` to `SkyReelsV2AdaLayerNorm` format in `_load_from_state_dict` --- .../transformers/transformer_skyreels_v2.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d9130f4c329a..5bbaa758f795 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -444,7 +444,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] - _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"] @register_to_config def __init__( @@ -513,6 +513,46 @@ def __init__( self.gradient_checkpointing = False + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + """ + Handle backward compatibility by converting deprecated `scale_shift_table` to the new `SkyReelsV2AdaLayerNorm` format. 
+ """ + # Check if this is an old checkpoint with scale_shift_table + scale_shift_table_key = prefix + "scale_shift_table" + if scale_shift_table_key in state_dict: + scale_shift_table = state_dict.pop(scale_shift_table_key) + + # The scale_shift_table has shape (1, 2, inner_dim) + inner_dim = scale_shift_table.shape[2] + + # Create identity matrices for the linear transformation + # This maintains the original behavior where the linear layer acts as input.dot(identity) + scale_shift_table + identity_matrix = torch.eye(inner_dim) + linear_weight = torch.cat([identity_matrix, identity_matrix], dim=0) + + # Set the linear layer weights and bias + state_dict[prefix + "norm_out.linear.weight"] = linear_weight.T + # The bias should contain the original scale_shift_table values + # scale_shift_table shape: (1, 2, inner_dim) -> flatten to (2 * inner_dim,) + state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten() + + # Handle FP32LayerNorm parameter renaming + old_norm_weight_key = prefix + "norm_out.weight" + if old_norm_weight_key in state_dict: + state_dict[prefix + "norm_out.norm.weight"] = state_dict.pop(old_norm_weight_key) + + old_norm_bias_key = prefix + "norm_out.bias" + if old_norm_bias_key in state_dict: + state_dict[prefix + "norm_out.norm.bias"] = state_dict.pop(old_norm_bias_key) + + logger.info("Converted deprecated 'scale_shift_table' to new 'SkyReelsV2AdaLayerNorm' format for backward compatibility.") + + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + def forward( self, hidden_states: torch.Tensor, From 374530f59c65c4a2f2eabc306dab6b5f288937d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 19:24:11 +0300 Subject: [PATCH 16/21] Adds missing keys to ignore on load and adjusts identity matrix creation for device and dtype consistency --- .../models/transformers/transformer_skyreels_v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 5bbaa758f795..969f6eaa6373 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -445,6 +445,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"] + _keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"] @register_to_config def __init__( @@ -529,7 +530,8 @@ def _load_from_state_dict( # Create identity matrices for the linear transformation # This maintains the original behavior where the linear layer acts as input.dot(identity) + scale_shift_table - identity_matrix = torch.eye(inner_dim) + # Use same device and dtype as the original scale_shift_table to avoid meta tensor issues + identity_matrix = torch.eye(inner_dim, device=scale_shift_table.device, dtype=scale_shift_table.dtype) linear_weight = torch.cat([identity_matrix, identity_matrix], dim=0) # Set the linear layer weights and bias @@ -538,7 +540,7 @@ def _load_from_state_dict( # scale_shift_table shape: (1, 2, inner_dim) -> flatten to (2 * inner_dim,) state_dict[prefix + "norm_out.linear.bias"] = 
scale_shift_table.flatten() - # Handle FP32LayerNorm parameter renaming + # Handle FP32LayerNorm parameter renaming: norm_out -> norm_out.norm old_norm_weight_key = prefix + "norm_out.weight" if old_norm_weight_key in state_dict: state_dict[prefix + "norm_out.norm.weight"] = state_dict.pop(old_norm_weight_key) From 9c7ac6e0d7b4f93065803f2d514d117feb98bf3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 17 Aug 2025 20:00:10 +0300 Subject: [PATCH 17/21] Fix device and dtype handling for identity matrix and scale_shift_table in SkyReelsV2Transformer3DModel --- .../models/transformers/transformer_skyreels_v2.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 969f6eaa6373..dc3c9b4a8289 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -530,15 +530,20 @@ def _load_from_state_dict( # Create identity matrices for the linear transformation # This maintains the original behavior where the linear layer acts as input.dot(identity) + scale_shift_table - # Use same device and dtype as the original scale_shift_table to avoid meta tensor issues - identity_matrix = torch.eye(inner_dim, device=scale_shift_table.device, dtype=scale_shift_table.dtype) + # If scale_shift_table is on meta device, create tensors on CPU to avoid meta tensor issues + device = scale_shift_table.device if scale_shift_table.device.type != "meta" else torch.device("cpu") + dtype = scale_shift_table.dtype + + identity_matrix = torch.eye(inner_dim, device=device, dtype=dtype) linear_weight = torch.cat([identity_matrix, identity_matrix], dim=0) # Set the linear layer weights and bias state_dict[prefix + "norm_out.linear.weight"] = linear_weight.T # The bias should contain the original scale_shift_table values # scale_shift_table shape: (1, 2, inner_dim) -> flatten to (2 * inner_dim,) - state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table.flatten() + # Move scale_shift_table to same device as the identity matrix to avoid device mismatch + scale_shift_table_flat = scale_shift_table.to(device).flatten() + state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table_flat # Handle FP32LayerNorm parameter renaming: norm_out -> norm_out.norm old_norm_weight_key = prefix + "norm_out.weight" From df9efaec266bb41dae8e93729a2c35716600d5bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Sun, 17 Aug 2025 20:45:17 +0300 Subject: [PATCH 18/21] Update normalization.py --- src/diffusers/models/normalization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 9cdac46d369a..ae2a6298f5f7 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -67,8 +67,7 @@ def forward( if self.emb is not None: temb = self.emb(timestep) - temb = self.silu(temb) - temb = self.linear(temb) + temb = self.linear(self.silu(temb)) if self.chunk_dim == 1: # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the From b27c0bf223f93e7d4afc4dbb1e6364399e369065 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:51:39 +0300 Subject: [PATCH 19/21] Delete 
src/diffusers/models/transformers/latte_transformer_3d.py --- .../transformers/latte_transformer_3d.py | 345 ------------------ 1 file changed, 345 deletions(-) delete mode 100644 src/diffusers/models/transformers/latte_transformer_3d.py diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py deleted file mode 100644 index 03284bc2a624..000000000000 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -import torch -from torch import nn - -from ...configuration_utils import ConfigMixin, register_to_config -from ..attention import BasicTransformerBlock -from ..cache_utils import CacheMixin -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid -from ..modeling_outputs import Transformer2DModelOutput -from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, AdaLayerNormSingle - - -class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): - _supports_gradient_checkpointing = True - _no_split_modules = ["norm_out"] - - """ - A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code: - https://github.com/Vchitect/Latte - - Parameters: - num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. - attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. - in_channels (`int`, *optional*): - The number of channels in the input. - out_channels (`int`, *optional*): - The number of channels in the output. - num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. - attention_bias (`bool`, *optional*): - Configure if the `TransformerBlocks` attention should contain a bias parameter. - sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). - This is fixed during training since it is used to learn a number of position embeddings. - patch_size (`int`, *optional*): - The size of the patches to use in the patch embedding layer. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): - The number of diffusion steps used during training. Pass if at least one of the norm_layers is - `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are - added to the hidden states. During inference, you can denoise for up to but not more steps than - `num_embeds_ada_norm`. 
- norm_type (`str`, *optional*, defaults to `"layer_norm"`): - The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. - norm_elementwise_affine (`bool`, *optional*, defaults to `True`): - Whether or not to use elementwise affine in normalization layers. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. - caption_channels (`int`, *optional*): - The number of channels in the caption embeddings. - video_length (`int`, *optional*): - The number of frames in the video-like data. - """ - - _skip_layerwise_casting_patterns = ["pos_embed", "norm"] - - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - out_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - sample_size: int = 64, - patch_size: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - norm_type: str = "layer_norm", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - caption_channels: int = None, - video_length: int = 16, - ): - super().__init__() - inner_dim = num_attention_heads * attention_head_dim - - # 1. Define input layers - self.height = sample_size - self.width = sample_size - - interpolation_scale = self.config.sample_size // 64 - interpolation_scale = max(interpolation_scale, 1) - self.pos_embed = PatchEmbed( - height=sample_size, - width=sample_size, - patch_size=patch_size, - in_channels=in_channels, - embed_dim=inner_dim, - interpolation_scale=interpolation_scale, - ) - - # 2. Define spatial transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for d in range(num_layers) - ] - ) - - # 3. Define temporal transformers blocks - self.temporal_transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=None, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - norm_type=norm_type, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - self.out_channels = in_channels if out_channels is None else out_channels - self.norm_out = AdaLayerNorm( - embedding_dim=inner_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=False, - norm_eps=1e-6, - chunk_dim=1, - ) - self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) - - # 5. Latte other blocks. 
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - - # define temporal positional embedding - temp_pos_embed = get_1d_sincos_pos_embed_from_grid( - inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt" - ) # 1152 hidden size - self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False) - - self.gradient_checkpointing = False - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - if "scale_shift_table" in state_dict: - scale_shift_table = state_dict.pop("scale_shift_table") - state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] - state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] - return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - - def forward( - self, - hidden_states: torch.Tensor, - timestep: Optional[torch.LongTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - enable_temporal_attentions: bool = True, - return_dict: bool = True, - ): - """ - The [`LatteTransformer3DModel`] forward method. - - Args: - hidden_states shape `(batch size, channel, num_frame, height, width)`: - Input `hidden_states`. - timestep ( `torch.LongTensor`, *optional*): - Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. - encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): - Conditional embeddings for cross attention layer. If not given, cross-attention defaults to - self-attention. - encoder_attention_mask ( `torch.Tensor`, *optional*): - Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: - - * Mask `(batcheight, sequence_length)` True = keep, False = discard. - * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard. - - If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format - above. This bias will be added to the cross-attention scores. - enable_temporal_attentions: - (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - - Returns: - If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a - `tuple` where the first element is the sample tensor. 
- """ - - # Reshape hidden states - batch_size, channels, num_frame, height, width = hidden_states.shape - # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width - hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) - - # Input - height, width = ( - hidden_states.shape[-2] // self.config.patch_size, - hidden_states.shape[-1] // self.config.patch_size, - ) - num_patches = height * width - - hidden_states = self.pos_embed(hidden_states) # already add positional embeddings - - added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - timestep, embedded_timestep = self.adaln_single( - timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype - ) - - # Prepare text embeddings for spatial block - # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size - encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 - encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave( - num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame - ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]) - - # Prepare timesteps for spatial and temporal block - timestep_spatial = timestep.repeat_interleave( - num_frame, dim=0, output_size=timestep.shape[0] * num_frame - ).view(-1, timestep.shape[-1]) - timestep_temp = timestep.repeat_interleave( - num_patches, dim=0, output_size=timestep.shape[0] * num_patches - ).view(-1, timestep.shape[-1]) - - # Spatial and temporal transformer blocks - for i, (spatial_block, temp_block) in enumerate( - zip(self.transformer_blocks, self.temporal_transformer_blocks) - ): - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( - spatial_block, - hidden_states, - None, # attention_mask - encoder_hidden_states_spatial, - encoder_attention_mask, - timestep_spatial, - None, # cross_attention_kwargs - None, # class_labels - ) - else: - hidden_states = spatial_block( - hidden_states, - None, # attention_mask - encoder_hidden_states_spatial, - encoder_attention_mask, - timestep_spatial, - None, # cross_attention_kwargs - None, # class_labels - ) - - if enable_temporal_attentions: - # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size - hidden_states = hidden_states.reshape( - batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] - ).permute(0, 2, 1, 3) - hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) - - if i == 0 and num_frame > 1: - hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype) - - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( - temp_block, - hidden_states, - None, # attention_mask - None, # encoder_hidden_states - None, # encoder_attention_mask - timestep_temp, - None, # cross_attention_kwargs - None, # class_labels - ) - else: - hidden_states = temp_block( - hidden_states, - None, # attention_mask - None, # encoder_hidden_states - None, # encoder_attention_mask - timestep_temp, - None, # cross_attention_kwargs - None, # class_labels - ) - - # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size - hidden_states = hidden_states.reshape( - batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] - ).permute(0, 2, 1, 3) - hidden_states = 
hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) - - embedded_timestep = embedded_timestep.repeat_interleave( - num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame - ).view(-1, embedded_timestep.shape[-1]) - hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) - hidden_states = self.proj_out(hidden_states) - - # unpatchify - if self.adaln_single is None: - height = width = int(hidden_states.shape[1] ** 0.5) - hidden_states = hidden_states.reshape( - shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) - ) - hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) - output = hidden_states.reshape( - shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) - ) - output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute( - 0, 2, 1, 3, 4 - ) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) From 05a8ce11d88929524b401aac018c83901bdf1b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 18 Aug 2025 10:59:03 +0300 Subject: [PATCH 20/21] Revert currently unrelated ones --- .../transformers/latte_transformer_3d.py | 331 ++++++++++++++++++ .../transformers/pixart_transformer_2d.py | 31 +- .../transformers/transformer_allegro.py | 29 +- .../models/transformers/transformer_ltx.py | 41 +-- .../models/transformers/transformer_wan.py | 54 ++- .../transformers/transformer_wan_vace.py | 47 +-- 6 files changed, 392 insertions(+), 141 deletions(-) create mode 100644 src/diffusers/models/transformers/latte_transformer_3d.py diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py new file mode 100644 index 000000000000..990c90512e39 --- /dev/null +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -0,0 +1,331 @@ +# Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..attention import BasicTransformerBlock +from ..cache_utils import CacheMixin +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): + _supports_gradient_checkpointing = True + + """ + A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code: + https://github.com/Vchitect/Latte + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input. + out_channels (`int`, *optional*): + The number of channels in the output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + patch_size (`int`, *optional*): + The size of the patches to use in the patch embedding layer. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. During inference, you can denoise for up to but not more steps than + `num_embeds_ada_norm`. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers. + caption_channels (`int`, *optional*): + The number of channels in the caption embeddings. + video_length (`int`, *optional*): + The number of frames in the video-like data. + """ + + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: int = 64, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + caption_channels: int = None, + video_length: int = 16, + ): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + + # 1. Define input layers + self.height = sample_size + self.width = sample_size + + interpolation_scale = self.config.sample_size // 64 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 2. 
Define spatial transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for d in range(num_layers) + ] + ) + + # 3. Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. Latte other blocks. + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + # define temporal positional embedding + temp_pos_embed = get_1d_sincos_pos_embed_from_grid( + inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt" + ) # 1152 hidden size + self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`LatteTransformer3DModel`] forward method. + + Args: + hidden_states shape `(batch size, channel, num_frame, height, width)`: + Input `hidden_states`. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batcheight, sequence_length)` True = keep, False = discard. + * Bias `(batcheight, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + enable_temporal_attentions: + (`bool`, *optional*, defaults to `True`): Whether to enable temporal attentions. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. 
+ + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + + # Reshape hidden states + batch_size, channels, num_frame, height, width = hidden_states.shape + # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width) + + # Input + height, width = ( + hidden_states.shape[-2] // self.config.patch_size, + hidden_states.shape[-1] // self.config.patch_size, + ) + num_patches = height * width + + hidden_states = self.pos_embed(hidden_states) # already add positional embeddings + + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # Prepare text embeddings for spatial block + # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size + encoder_hidden_states = self.caption_projection(encoder_hidden_states) # 3 120 1152 + encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave( + num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame + ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]) + + # Prepare timesteps for spatial and temporal block + timestep_spatial = timestep.repeat_interleave( + num_frame, dim=0, output_size=timestep.shape[0] * num_frame + ).view(-1, timestep.shape[-1]) + timestep_temp = timestep.repeat_interleave( + num_patches, dim=0, output_size=timestep.shape[0] * num_patches + ).view(-1, timestep.shape[-1]) + + # Spatial and temporal transformer blocks + for i, (spatial_block, temp_block) in enumerate( + zip(self.transformer_blocks, self.temporal_transformer_blocks) + ): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + spatial_block, + hidden_states, + None, # attention_mask + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + None, # cross_attention_kwargs + None, # class_labels + ) + else: + hidden_states = spatial_block( + hidden_states, + None, # attention_mask + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + None, # cross_attention_kwargs + None, # class_labels + ) + + if enable_temporal_attentions: + # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size + hidden_states = hidden_states.reshape( + batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] + ).permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) + + if i == 0 and num_frame > 1: + hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + None, # cross_attention_kwargs + None, # class_labels + ) + else: + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + None, # cross_attention_kwargs + None, # class_labels + ) + + # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) 
num_tokens hidden_size + hidden_states = hidden_states.reshape( + batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1] + ).permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1]) + + embedded_timestep = embedded_timestep.repeat_interleave( + num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame + ).view(-1, embedded_timestep.shape[-1]) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size) + ) + output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute( + 0, 2, 1, 3, 4 + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 03fb1f8c30b8..40a14bfd9b27 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -23,7 +23,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, AdaLayerNormSingle +from ..normalization import AdaLayerNormSingle logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -78,7 +78,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "PatchEmbed", "norm_out"] + _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config @@ -171,13 +171,8 @@ def __init__( ) # 3. Output blocks. 
- self.norm_out = AdaLayerNorm( - embedding_dim=self.inner_dim, - output_dim=2 * self.inner_dim, - norm_elementwise_affine=False, - norm_eps=1e-6, - chunk_dim=1, - ) + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels) self.adaln_single = AdaLayerNormSingle( @@ -189,17 +184,6 @@ def __init__( in_features=self.config.caption_channels, hidden_size=self.inner_dim ) - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - if "scale_shift_table" in state_dict: - scale_shift_table = state_dict.pop("scale_shift_table") - state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] - state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] - return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: @@ -422,7 +406,12 @@ def forward( ) # 3. Output - hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) + shift, scale = ( + self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index fae075985935..5fa59a71d977 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -28,7 +28,7 @@ from ..embeddings import PatchEmbed, PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, AdaLayerNormSingle +from ..normalization import AdaLayerNormSingle logger = logging.get_logger(__name__) @@ -175,7 +175,6 @@ def forward( class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin): _supports_gradient_checkpointing = True - _no_split_modules = ["norm_out"] """ A 3D Transformer model for video-like data. @@ -293,13 +292,8 @@ def __init__( ) # 3. Output projection & norm - self.norm_out = AdaLayerNorm( - embedding_dim=self.inner_dim, - output_dim=2 * self.inner_dim, - norm_elementwise_affine=False, - norm_eps=1e-6, - chunk_dim=1, - ) + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) # 4. 
Timestep embeddings @@ -310,17 +304,6 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - if "scale_shift_table" in state_dict: - scale_shift_table = state_dict.pop("scale_shift_table") - state_dict[prefix + "norm_out.linear.weight"] = scale_shift_table[1] - state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table[0] - return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, hidden_states: torch.Tensor, @@ -410,7 +393,11 @@ def forward( ) # 4. Output normalization & projection - hidden_states = self.norm_out(hidden_states, temb=embedded_timestep) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + + # Modulation + hidden_states = hidden_states * (1 + scale) + shift hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 6c1f8b9758c8..79149fb76067 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -30,7 +30,7 @@ from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, AdaLayerNormSingle, RMSNorm +from ..normalization import AdaLayerNormSingle, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -408,7 +408,6 @@ class LTXVideoTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] - _no_split_modules = ["norm_out"] _repeated_blocks = ["LTXVideoTransformerBlock"] @register_to_config @@ -437,6 +436,7 @@ def __init__( self.proj_in = nn.Linear(in_channels, inner_dim) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) @@ -469,40 +469,11 @@ def __init__( ] ) - self.norm_out = AdaLayerNorm( - embedding_dim=inner_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=False, - norm_eps=1e-6, - chunk_dim=1, - ) + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels) self.gradient_checkpointing = False - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - key = "scale_shift_table" - if prefix + key in state_dict: - scale_shift_table = state_dict.pop(prefix + key) - inner_dim = scale_shift_table.shape[-1] - - weight = torch.eye(inner_dim).repeat(2, 1) - bias = scale_shift_table.reshape(2, inner_dim).flatten() - - state_dict[prefix + "norm_out.linear.weight"] = weight - state_dict[prefix + "norm_out.linear.bias"] = bias - - if prefix + "norm_out.weight" in state_dict: - state_dict.pop(prefix + "norm_out.weight") - if prefix + "norm_out.bias" in state_dict: - state_dict.pop(prefix + "norm_out.bias") - - return super(LTXVideoTransformer3DModel, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, 
hidden_states: torch.Tensor, @@ -573,7 +544,11 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) - hidden_states = self.norm_out(hidden_states, temb=embedded_timestep.squeeze(1)) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift output = self.proj_out(hidden_states) if USE_PEFT_BACKEND: diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 16ba30863bfc..968a0369c243 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -29,7 +29,7 @@ from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, FP32LayerNorm +from ..normalization import FP32LayerNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -535,7 +535,7 @@ class WanTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock", "norm_out"] + _no_split_modules = ["WanTransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] @@ -591,40 +591,12 @@ def __init__( ) # 4. Output norm & projection - self.norm_out = AdaLayerNorm( - embedding_dim=inner_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=False, - norm_eps=eps, - chunk_dim=1, - ) + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - key = "scale_shift_table" - if prefix + key in state_dict: - scale_shift_table = state_dict.pop(prefix + key) - inner_dim = scale_shift_table.shape[-1] - - weight = torch.eye(inner_dim).repeat(2, 1) - bias = scale_shift_table.reshape(2, inner_dim).flatten() - - state_dict[prefix + "norm_out.linear.weight"] = weight - state_dict[prefix + "norm_out.linear.bias"] = bias - - if prefix + "norm_out.weight" in state_dict: - state_dict.pop(prefix + "norm_out.weight") - if prefix + "norm_out.bias" in state_dict: - state_dict.pop(prefix + "norm_out.bias") - - return super(WanTransformer3DModel, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, hidden_states: torch.Tensor, @@ -691,7 +663,23 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. 
Output norm, projection & unpatchify - hidden_states = self.norm_out(hidden_states, temb=temb) + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py index a76be1da99b3..e039d362193d 100644 --- a/src/diffusers/models/transformers/transformer_wan_vace.py +++ b/src/diffusers/models/transformers/transformer_wan_vace.py @@ -25,7 +25,7 @@ from ..cache_utils import CacheMixin from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNorm, FP32LayerNorm +from ..normalization import FP32LayerNorm from .transformer_wan import ( WanAttention, WanAttnProcessor, @@ -173,7 +173,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock", "norm_out"] + _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @@ -253,40 +253,12 @@ def __init__( ) # 4. 
Output norm & projection - self.norm_out = AdaLayerNorm( - embedding_dim=inner_dim, - output_dim=2 * inner_dim, - norm_elementwise_affine=False, - norm_eps=eps, - chunk_dim=1, - ) + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - key = "scale_shift_table" - if prefix + key in state_dict: - scale_shift_table = state_dict.pop(prefix + key) - inner_dim = scale_shift_table.shape[-1] - - weight = torch.eye(inner_dim).repeat(2, 1) - bias = scale_shift_table.reshape(2, inner_dim).flatten() - - state_dict[prefix + "norm_out.linear.weight"] = weight - state_dict[prefix + "norm_out.linear.bias"] = bias - - if prefix + "norm_out.weight" in state_dict: - state_dict.pop(prefix + "norm_out.weight") - if prefix + "norm_out.bias" in state_dict: - state_dict.pop(prefix + "norm_out.bias") - - return super(WanVACETransformer3DModel, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, hidden_states: torch.Tensor, @@ -387,7 +359,16 @@ def forward( hidden_states = hidden_states + control_hint * scale # 6. Output norm, projection & unpatchify - hidden_states = self.norm_out(hidden_states, temb=temb) + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( From 6bbf350bdea2c678751cb2c803e2db5caf7e35b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 19 Aug 2025 20:59:41 +0300 Subject: [PATCH 21/21] Remove `_load_from_state_dict` --- .../transformers/transformer_skyreels_v2.py | 46 ------------------- 1 file changed, 46 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index dc3c9b4a8289..b17e30ce2717 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -514,52 +514,6 @@ def __init__( self.gradient_checkpointing = False - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - """ - Handle backward compatibility by converting deprecated `scale_shift_table` to the new `SkyReelsV2AdaLayerNorm` format. 
- """ - # Check if this is an old checkpoint with scale_shift_table - scale_shift_table_key = prefix + "scale_shift_table" - if scale_shift_table_key in state_dict: - scale_shift_table = state_dict.pop(scale_shift_table_key) - - # The scale_shift_table has shape (1, 2, inner_dim) - inner_dim = scale_shift_table.shape[2] - - # Create identity matrices for the linear transformation - # This maintains the original behavior where the linear layer acts as input.dot(identity) + scale_shift_table - # If scale_shift_table is on meta device, create tensors on CPU to avoid meta tensor issues - device = scale_shift_table.device if scale_shift_table.device.type != "meta" else torch.device("cpu") - dtype = scale_shift_table.dtype - - identity_matrix = torch.eye(inner_dim, device=device, dtype=dtype) - linear_weight = torch.cat([identity_matrix, identity_matrix], dim=0) - - # Set the linear layer weights and bias - state_dict[prefix + "norm_out.linear.weight"] = linear_weight.T - # The bias should contain the original scale_shift_table values - # scale_shift_table shape: (1, 2, inner_dim) -> flatten to (2 * inner_dim,) - # Move scale_shift_table to same device as the identity matrix to avoid device mismatch - scale_shift_table_flat = scale_shift_table.to(device).flatten() - state_dict[prefix + "norm_out.linear.bias"] = scale_shift_table_flat - - # Handle FP32LayerNorm parameter renaming: norm_out -> norm_out.norm - old_norm_weight_key = prefix + "norm_out.weight" - if old_norm_weight_key in state_dict: - state_dict[prefix + "norm_out.norm.weight"] = state_dict.pop(old_norm_weight_key) - - old_norm_bias_key = prefix + "norm_out.bias" - if old_norm_bias_key in state_dict: - state_dict[prefix + "norm_out.norm.bias"] = state_dict.pop(old_norm_bias_key) - - logger.info("Converted deprecated 'scale_shift_table' to new 'SkyReelsV2AdaLayerNorm' format for backward compatibility.") - - return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) - def forward( self, hidden_states: torch.Tensor,