36 changes: 33 additions & 3 deletions src/diffusers/models/transformers/transformer_cogview4.py
@@ -28,8 +28,7 @@
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous

from ..normalization import LayerNorm, RMSNorm

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -584,6 +583,37 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
        return (freqs.cos(), freqs.sin())


class _CogViewFinalAdaLayerNormContinuous(nn.Module):
"""
CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine.
Matches Megatron: **no activation** before the Linear on conditioning embedding.
"""
def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
        norm_type: str = "layer_norm",
    ):
        super().__init__()
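        # Project the conditioning embedding to 2 * embedding_dim; split downstream into (scale, shift).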
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        # Unlike AdaLayerNormContinuous, the conditioning embedding is NOT passed through SiLU here;
        # this matches the Megatron reference implementation of CogView4's final norm.
        emb = self.linear(conditioning_embedding.to(x.dtype))
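        # Per-channel scale and shift; [:, None, :] broadcasts them over the sequence dimension.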
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
@@ -666,7 +696,7 @@ def __init__(
        )

        # 4. Output projection
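        # CogView4 applies its final AdaLN without SiLU on the conditioning embedding, hence the local variant.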
        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
        self.norm_out = _CogViewFinalAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

        self.gradient_checkpointing = False