
Commit bbafd53

武嘉涵 authored and committed
CogView4: use local final AdaLN (no SiLU) per review; keep generic AdaLN unchanged
1 parent 39294e3 commit bbafd53
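For context, a minimal sketch of the behavioral difference the commit message describes (illustrative only, not the diffusers source; function names are made up): the generic continuous AdaLN applies a SiLU to the conditioning embedding before its Linear projection, while the CogView4-local final norm added below feeds the conditioning embedding to the Linear directly, matching the Megatron reference.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def generic_final_adaln(x, cond, norm, linear):
    # generic AdaLayerNormContinuous-style path (assumed): SiLU -> Linear -> chunk
    scale, shift = linear(F.silu(cond.to(x.dtype))).chunk(2, dim=1)
    return norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]

def cogview4_final_adaln(x, cond, norm, linear):
    # CogView4 path in this commit: Linear applied to the raw conditioning embedding
    scale, shift = linear(cond.to(x.dtype)).chunk(2, dim=1)
    return norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]

# toy shapes: hidden size 8, conditioning dim 4
norm = nn.LayerNorm(8, elementwise_affine=False)
linear = nn.Linear(4, 16)
x, cond = torch.randn(2, 3, 8), torch.randn(2, 4)
print(torch.allclose(generic_final_adaln(x, cond, norm, linear),
                     cogview4_final_adaln(x, cond, norm, linear)))  # False in general
```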

File tree

1 file changed: +33 -3 lines changed


src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 33 additions & 3 deletions
@@ -28,8 +28,7 @@
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
-
+from ..normalization import LayerNorm, RMSNorm

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -584,6 +583,37 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         return (freqs.cos(), freqs.sin())


+class _CogViewFinalAdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine.
+    Matches Megatron: **no activation** before the Linear on conditioning embedding.
+    """
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
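For orientation, a small usage sketch of the class added above; the tensor sizes are illustrative and not taken from any CogView4 config.

```python
import torch

# hypothetical sizes: hidden dim 64, conditioning (time embedding) dim 32
norm_out = _CogViewFinalAdaLayerNormContinuous(
    embedding_dim=64, conditioning_embedding_dim=32, elementwise_affine=False
)
hidden_states = torch.randn(2, 16, 64)  # (batch, tokens, embedding_dim)
temb = torch.randn(2, 32)               # (batch, conditioning_embedding_dim)
out = norm_out(hidden_states, temb)     # shift/scale-modulated LayerNorm output
assert out.shape == hidden_states.shape
```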
@@ -666,7 +696,7 @@ def __init__(
         )

         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False, use_silu=False)
+        self.norm_out = _CogViewFinalAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

         self.gradient_checkpointing = False
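One practical note, hedged: because the replacement module keeps the same `linear` submodule (conditioning_embedding_dim -> 2 * embedding_dim) and the norm carries no affine parameters at this call site, existing `norm_out` checkpoint weights should load with unchanged keys and shapes; only the forward pass (SiLU vs. no SiLU) differs. A quick sanity sketch, assuming both classes are importable in the same session:

```python
from diffusers.models.normalization import AdaLayerNormContinuous

old = AdaLayerNormContinuous(64, 32, elementwise_affine=False)
new = _CogViewFinalAdaLayerNormContinuous(64, 32, elementwise_affine=False)
# parameter layout matches, so state_dict keys and shapes line up one-to-one
assert {k: tuple(v.shape) for k, v in old.state_dict().items()} == \
       {k: tuple(v.shape) for k, v in new.state_dict().items()}
```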

0 commit comments
