 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from operator import ipow
 from typing import Any, Dict, Optional, Tuple
 
 import torch
-from torch._prims_common import is_low_precision_dtype
 import torch.nn as nn
-from transformers.tokenization_utils_base import import_protobuf_decode_error
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, MochiAttnProcessor2_0
+from ..attention_processor import MochiAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import (
     AdaLayerNormContinuous,
-    LuminaLayerNormContinuous,
 )
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class FP32ModulatedRMSNorm(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+class MochiModulatedRMSNorm(nn.Module):
+    def __init__(self, eps: float):
         super().__init__()
 
         self.eps = eps
 
     def forward(self, hidden_states, scale=None):
+        hidden_states_dtype = hidden_states.dtype
+
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
+        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
 
         if scale is not None:
             hidden_states = hidden_states * scale
 
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
+        return hidden_states
+
+
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine=True):
+        super().__init__()
+
+        self.eps = eps
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        hidden_states_dtype = hidden_states.dtype
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
         return hidden_states
 
 
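# Editor's note: a minimal standalone sketch (not part of the PR) of the fp32 RMS-norm
# computation shared by the two classes above. MochiModulatedRMSNorm applies an externally
# computed modulation scale and owns no parameters, while MochiRMSNorm owns a learned
# per-channel weight; both accumulate the variance in float32 and cast the result back to
# the input dtype. The helper name and shapes below are illustrative assumptions.
from typing import Optional

import torch


def rms_norm_fp32(x: torch.Tensor, eps: float = 1e-5, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = x.to(torch.float32)
    out = out * torch.rsqrt(out.pow(2).mean(-1, keepdim=True) + eps)
    if scale is not None:
        out = out * scale  # e.g. (1 + conditioning_scale), kept in float32
    return out.to(x.dtype)


x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
y = rms_norm_fp32(x, scale=torch.ones(1, 1, 8))
assert y.dtype == torch.bfloat16
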
@@ -59,49 +86,28 @@ def __init__(
         self,
         embedding_dim: int,
         conditioning_embedding_dim: int,
-        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
-        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
-        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
-        # However, this is how it was implemented in the original code, and it's rather likely you should
-        # set `elementwise_affine` to False.
-        elementwise_affine=True,
         eps=1e-5,
         bias=True,
-        norm_type="layer_norm",
-        out_dim: Optional[int] = None,
     ):
         super().__init__()
 
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
-
-        if norm_type == "layer_norm":
-            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
-        elif norm_type == "rms_norm":
-            self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
-        else:
-            raise ValueError(f"unknown norm_type {norm_type}")
-
-        self.linear_2 = None
-        if out_dim is not None:
-            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
+        self.norm = MochiModulatedRMSNorm(eps=eps)
 
     def forward(
         self,
         x: torch.Tensor,
         conditioning_embedding: torch.Tensor,
     ) -> torch.Tensor:
-        output_dtype = x.dtype
-        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
-        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
-        scale = emb
-        x = self.norm(x, (1 + scale.unsqueeze(1).float()))
+        input_dtype = x.dtype
 
-        if self.linear_2 is not None:
-            x = self.linear_2(x)
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+        scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+        x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
 
-        return x.to(output_dtype)
+        return x.to(input_dtype)
 
 
 class MochiRMSNormZero(nn.Module):
@@ -119,7 +125,7 @@ def __init__(
 
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, hidden_dim)
-        self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm = MochiModulatedRMSNorm(eps=eps)
 
     def forward(
         self, hidden_states: torch.Tensor, emb: torch.Tensor
@@ -129,12 +135,76 @@ def forward(
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
 
-        hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].float()))
+        hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].to(torch.float32)))
         hidden_states = hidden_states.to(hidden_states_dtype)
 
         return hidden_states, gate_msa, scale_mlp, gate_mlp
 
 
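# Editor's note: a schematic, self-contained sketch (not part of the PR) of how the four
# chunks returned by MochiRMSNormZero are typically consumed, mirroring the
# MochiTransformerBlock.forward changes further down: scale_* modulates the normalized
# activations fed into a sub-layer, and tanh(gate_*) gates the sub-layer output before the
# residual add. The sub_layer stand-in and tensor shapes are illustrative assumptions.
import torch

batch, seq_len, dim = 2, 16, 8
hidden_states = torch.randn(batch, seq_len, dim)
emb = torch.randn(batch, 4 * dim)  # plays the role of the nn.Linear(embedding_dim, hidden_dim) output above

scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)  # each (batch, dim)
sub_layer = torch.nn.Identity()  # stand-in for the attention / feed-forward sub-layer

modulated = hidden_states * (1 + scale_msa[:, None].to(torch.float32))
hidden_states = hidden_states + torch.tanh(gate_msa).unsqueeze(1) * sub_layer(modulated)
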
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        processor: Optional["MochiAttnProcessor2_0"],
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps)
+        self.norm_k = MochiRMSNorm(dim_head, eps)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
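# Editor's note: a toy, self-contained sketch (not part of the PR) of the processor-delegation
# pattern MochiAttention follows above: the module owns the projections and norms, and
# forward() simply hands everything to a swappable processor callable (MochiAttnProcessor2_0
# in this diff). ToyProcessor and ToyAttention are illustrative stand-ins; the real
# processor's call signature is not shown in this hunk.
import torch
import torch.nn as nn


class ToyProcessor:
    def __call__(self, attn: "ToyAttention", hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        # the processor receives the module itself, so it can reuse its sub-modules freely
        return attn.to_out(attn.to_v(hidden_states))


class ToyAttention(nn.Module):
    def __init__(self, dim: int, processor) -> None:
        super().__init__()
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        self.processor = processor

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        return self.processor(self, hidden_states, **kwargs)


out = ToyAttention(8, ToyProcessor())(torch.randn(2, 4, 8))
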
 @maybe_allow_in_graph
 class MochiTransformerBlock(nn.Module):
     r"""
@@ -183,34 +253,28 @@ def __init__(
             embedding_dim=pooled_projection_dim,
             conditioning_embedding_dim=dim,
             eps=eps,
-            elementwise_affine=False,
-            norm_type="rms_norm",
-            out_dim=None,
         )
 
-        self.attn1 = Attention(
+        self.attn1 = MochiAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             bias=False,
-            qk_norm=qk_norm,
             added_kv_proj_dim=pooled_projection_dim,
             added_proj_bias=False,
             out_dim=dim,
             out_context_dim=pooled_projection_dim,
             context_pre_only=context_pre_only,
             processor=MochiAttnProcessor2_0(),
             eps=1e-5,
-            elementwise_affine=True,
         )
 
         # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
-        self.norm2 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm2_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm2 = MochiModulatedRMSNorm(eps=eps)
+        self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
 
-        self.norm3 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm3_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm3 = MochiModulatedRMSNorm(eps)
+        self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
 
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
@@ -222,8 +286,8 @@ def __init__(
                 bias=False,
             )
 
-        self.norm4 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm4_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm4 = MochiModulatedRMSNorm(eps=eps)
+        self.norm4_context = MochiModulatedRMSNorm(eps=eps)
 
     def forward(
         self,
@@ -249,26 +313,22 @@ def forward(
             attention_mask=joint_attention_mask,
         )
 
-        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)).to(
-            hidden_states.dtype
-        )
-        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())).to(hidden_states.dtype)
+        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
+        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)).to(
-            hidden_states.dtype
-        )
+        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
 
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
-            ).to(encoder_hidden_states.dtype)
+            )
             norm_encoder_hidden_states = self.norm3_context(
-                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float())
-            ).to(encoder_hidden_states.dtype)
+                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
+            )
             context_ff_output = self.ff_context(norm_encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states + self.norm4_context(
                 context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
-            ).to(encoder_hidden_states.dtype)
+            )
 
         return hidden_states, encoder_hidden_states
 