 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -584,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         return (freqs.cos(), freqs.sin())
 
 
+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+    Linear on conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
@@ -666,7 +698,7 @@ def __init__(
         )
 
         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
 
         self.gradient_checkpointing = False
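For context, the only behavioral difference from the stock diffusers AdaLayerNormContinuous (which computes emb = linear(silu(temb))) is that the conditioning embedding here goes straight into the Linear, so weights converted from Megatron line up without an extra activation. A minimal standalone sketch of the new forward pass follows; the sizes (inner_dim=64, time_embed_dim=128) and tensor shapes are toy values for illustration, not the real CogView4 config:

import torch
import torch.nn as nn

# Stock AdaLayerNormContinuous:   emb = linear(silu(temb))
# CogView4 variant in this diff:  emb = linear(temb)   <- no activation on temb

inner_dim, time_embed_dim = 64, 128            # toy sizes for illustration
hidden_states = torch.randn(2, 16, inner_dim)  # (batch, seq_len, inner_dim)
temb = torch.randn(2, time_embed_dim)          # pooled conditioning embedding

linear = nn.Linear(time_embed_dim, inner_dim * 2)
norm = nn.LayerNorm(inner_dim, elementwise_affine=False)

emb = linear(temb)                             # conditioning used un-activated
scale, shift = emb.chunk(2, dim=1)
out = norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
print(out.shape)                               # torch.Size([2, 16, 64])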