@@ -26,7 +26,7 @@
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,7 +55,14 @@ def __init__(
         if not context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
-            self.norm1_context = nn.Linear(dim, pooled_projection_dim)
+            self.norm1_context = LuminaLayerNormContinuous(
+                embedding_dim=pooled_projection_dim,
+                conditioning_embedding_dim=dim,
+                eps=1e-6,
+                elementwise_affine=False,
+                norm_type="rms_norm",
+                out_dim=None,
+            )
 
         self.attn1 = Attention(
             query_dim=dim,
@@ -83,7 +90,9 @@ def __init__(
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
         if not context_pre_only:
-            self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False)
+            self.ff_context = FeedForward(
+                pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
+            )
 
         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
@@ -102,7 +111,7 @@ def forward(
                 encoder_hidden_states, temb
             )
         else:
-            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
 
         attn_hidden_states, context_attn_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
@@ -112,7 +121,7 @@ def forward(
 
         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
         norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
-        
+
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states
@@ -207,7 +216,9 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p
 
-        temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype)
+        temb, encoder_hidden_states = self.time_embed(
+            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
+        )
 
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.patch_embed(hidden_states)
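
Note on the norm1_context change: the context_pre_only branch now conditions the caption states on temb (hence the extra argument in forward) instead of passing them through an unconditioned nn.Linear. Below is a minimal sketch of that conditioning pattern, assuming an RMS norm followed by a temb-derived scale; the class name, the SiLU + Linear projection, and the example shapes are illustrative assumptions, not the LuminaLayerNormContinuous source.

# Sketch only: RMS-normalize the context states, then scale them with a
# projection of the timestep embedding. Not the diffusers implementation.
import torch
import torch.nn as nn


class ConditionedRMSNormSketch(nn.Module):
    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.silu = nn.SiLU()
        # Projects the conditioning (temb) to a per-channel scale over the
        # normalized states. The SiLU + Linear choice is an assumption.
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # Plain RMS norm without learnable affine parameters.
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        x_norm = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)
        # Timestep-conditioned scale, broadcast over the sequence dimension.
        scale = self.linear(self.silu(temb)).to(x.dtype)
        return x_norm * (1.0 + scale.unsqueeze(1))


# Usage with placeholder shapes (not taken from the diff):
# encoder_hidden_states: [batch, seq, 1536], temb: [batch, 3072]
norm1_context = ConditionedRMSNormSketch(embedding_dim=1536, conditioning_embedding_dim=3072)
out = norm1_context(torch.randn(2, 256, 1536), torch.randn(2, 3072))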