@@ -28,8 +28,7 @@
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
-
+from ..normalization import LayerNorm, RMSNorm
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -584,6 +583,37 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens |
         return (freqs.cos(), freqs.sin())
 
 
+class _CogViewFinalAdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine.
+    Matches Megatron: **no activation** before the Linear on conditioning embedding.
+    """
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
@@ -666,7 +696,7 @@ def __init__( |
         )
 
         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False, use_silu=False)
+        self.norm_out = _CogViewFinalAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
 
         self.gradient_checkpointing = False
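For orientation (not part of the diff above): a minimal usage sketch of the new final norm, assuming the _CogViewFinalAdaLayerNormContinuous class is in scope and using illustrative sizes (embedding_dim=2560 and conditioning_embedding_dim=512 are placeholders, not values taken from this PR). The conditioning embedding is projected directly to a scale/shift pair, with no SiLU beforehand, and applied to the layer-normalized hidden states.

import torch

# hypothetical sizes, for illustration only
norm_out = _CogViewFinalAdaLayerNormContinuous(
    embedding_dim=2560, conditioning_embedding_dim=512, elementwise_affine=False
)
hidden_states = torch.randn(2, 4096, 2560)  # (batch, tokens, inner_dim)
temb = torch.randn(2, 512)                  # pooled timestep/condition embedding
out = norm_out(hidden_states, temb)         # (2, 4096, 2560): LN(x) * (1 + scale) + shift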