
Commit da65bbc

Merge branch 'main' into DC-AE-turbo
2 parents 58592bd + 76c809e

File tree

1 file changed (+34, -2 lines)


src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 34 additions & 2 deletions
@@ -28,7 +28,7 @@
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -584,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         return (freqs.cos(), freqs.sin())


+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before
+    the Linear on the conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
@@ -666,7 +698,7 @@ def __init__(
         )

         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

         self.gradient_checkpointing = False
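
For reference, a minimal smoke test of the new block (not part of the commit). It assumes a diffusers checkout containing this change is importable under the path shown; the tensor shapes are illustrative placeholders, not values from the CogView4 config.

import torch

from diffusers.models.transformers.transformer_cogview4 import (
    CogView4AdaLayerNormContinuous,
)

# Mirrors how the model constructs it: elementwise_affine=False, default "layer_norm".
norm_out = CogView4AdaLayerNormContinuous(
    embedding_dim=64,               # stands in for inner_dim
    conditioning_embedding_dim=32,  # stands in for time_embed_dim
    elementwise_affine=False,
)

x = torch.randn(2, 16, 64)   # (batch, tokens, embedding_dim)
cond = torch.randn(2, 32)    # (batch, conditioning_embedding_dim)

out = norm_out(x, cond)
assert out.shape == x.shape  # AdaLN modulation is shape-preserving

The only behavioral difference from the AdaLayerNormContinuous it replaces is the missing activation: the conditioning embedding feeds the Linear directly, which is what the docstring's "Matches Megatron" note refers to.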
