|
28 | 28 | from ..embeddings import CogView3CombinedTimestepSizeEmbeddings |
29 | 29 | from ..modeling_outputs import Transformer2DModelOutput |
30 | 30 | from ..modeling_utils import ModelMixin |
31 | | -from ..normalization import AdaLayerNormContinuous |
| 31 | +from ..normalization import LayerNorm, RMSNorm |
32 | 32 |
|
33 | 33 |
|
34 | 34 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -584,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens |
584 | 584 | return (freqs.cos(), freqs.sin()) |
585 | 585 |
|
586 | 586 |
|
| 587 | +class CogView4AdaLayerNormContinuous(nn.Module):
| 588 | +    """
| 589 | +    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches the Megatron reference
| 590 | +    implementation: no activation is applied to the conditioning embedding before the Linear.
| 591 | +    """
| 592 | +
| 593 | +    def __init__(
| 594 | +        self,
| 595 | +        embedding_dim: int,
| 596 | +        conditioning_embedding_dim: int,
| 597 | +        elementwise_affine: bool = True,
| 598 | +        eps: float = 1e-5,
| 599 | +        bias: bool = True,
| 600 | +        norm_type: str = "layer_norm",
| 601 | +    ):
| 602 | +        super().__init__()
| 603 | +        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
| 604 | +        if norm_type == "layer_norm":
| 605 | +            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
| 606 | +        elif norm_type == "rms_norm":
| 607 | +            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
| 608 | +        else:
| 609 | +            raise ValueError(f"unknown norm_type {norm_type}")
| 610 | +
| 611 | +    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
| 612 | +        # Unlike the shared AdaLayerNormContinuous, no SiLU is applied to the conditioning embedding.
| 613 | +        emb = self.linear(conditioning_embedding.to(x.dtype))
| 614 | +        scale, shift = torch.chunk(emb, 2, dim=1)
| 615 | +        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
| 616 | +        return x
| 617 | + |
| 618 | + |
587 | 619 | class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): |
588 | 620 | r""" |
589 | 621 | Args: |
@@ -666,7 +698,7 @@ def __init__( |
666 | 698 | ) |
667 | 699 |
|
668 | 700 | # 4. Output projection |
669 | | - self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) |
| 701 | + self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) |
670 | 702 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) |
671 | 703 |
|
672 | 704 | self.gradient_checkpointing = False |
|
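For context outside the diff, here is a minimal sketch of the math the new `CogView4AdaLayerNormContinuous` performs, written with plain `torch.nn` modules rather than the diffusers classes. All sizes, tensor shapes, and variable names below are assumptions chosen for illustration, not values taken from this PR.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
batch, seq_len, inner_dim, time_embed_dim = 2, 16, 64, 32

linear = nn.Linear(time_embed_dim, inner_dim * 2)
norm = nn.LayerNorm(inner_dim, elementwise_affine=False)

hidden_states = torch.randn(batch, seq_len, inner_dim)  # patchified image tokens
temb = torch.randn(batch, time_embed_dim)                # pooled conditioning embedding

# No SiLU on temb before the Linear -- this is the difference from the shared AdaLayerNormContinuous.
scale, shift = linear(temb).chunk(2, dim=1)
out = norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
assert out.shape == hidden_states.shape
```

The `norm_out` swap at line 701 wires exactly this projection into the output stage, with `inner_dim` and `time_embed_dim` taken from the model config.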