diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index dc45befb98fa..25dcfa14cc0b 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -28,7 +28,7 @@
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -584,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
         return (freqs.cos(), freqs.sin())
 
 
+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+    Linear on the conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
@@ -666,7 +698,7 @@ def __init__(
         )
 
         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
 
         self.gradient_checkpointing = False
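
Note (not part of the patch): a minimal sketch of the behavioral difference the new class encodes. Diffusers' stock AdaLayerNormContinuous applies SiLU to the conditioning embedding before its Linear; the CogView4 variant above feeds the conditioning embedding straight into the Linear, matching Megatron. The shapes below are made up purely for illustration.

    # Sketch only, outside the patch; illustrative shapes, not real model dims.
    import torch
    import torch.nn as nn

    x = torch.randn(2, 16, 64)   # (batch, seq_len, embedding_dim)
    cond = torch.randn(2, 128)   # (batch, conditioning_embedding_dim)

    linear = nn.Linear(128, 64 * 2)
    norm = nn.LayerNorm(64, elementwise_affine=False)

    emb_stock = linear(nn.functional.silu(cond))   # stock AdaLayerNormContinuous: SiLU before Linear
    emb_cogview4 = linear(cond)                    # CogView4AdaLayerNormContinuous: no activation

    for emb in (emb_stock, emb_cogview4):
        scale, shift = torch.chunk(emb, 2, dim=1)
        out = norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        print(out.shape)  # torch.Size([2, 16, 64])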