 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, FluxAttnProcessor2_0
+from ..attention_processor import Attention, MochiAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -43,22 +43,23 @@ def __init__(
         qk_norm: str = "rms_norm",
         activation_fn: str = "swiglu",
         context_pre_only: bool = True,
+        eps: float = 1e-6,
     ) -> None:
         super().__init__()

         self.context_pre_only = context_pre_only
         self.ff_inner_dim = (4 * dim * 2) // 3
         self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3

-        self.norm1 = MochiRMSNormZero(dim, 4 * dim)
+        self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)

         if not context_pre_only:
-            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
+            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
         else:
             self.norm1_context = LuminaLayerNormContinuous(
                 embedding_dim=pooled_projection_dim,
                 conditioning_embedding_dim=dim,
-                eps=1e-6,
+                eps=eps,
                 elementwise_affine=False,
                 norm_type="rms_norm",
                 out_dim=None,
@@ -76,16 +77,16 @@ def __init__(
             out_dim=dim,
             out_context_dim=pooled_projection_dim,
             context_pre_only=context_pre_only,
-            processor=FluxAttnProcessor2_0(),
-            eps=1e-6,
+            processor=MochiAttnProcessor2_0(),
+            eps=eps,
             elementwise_affine=True,
         )

-        self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
-        self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
+        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

-        self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
-        self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
+        self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
@@ -94,8 +95,8 @@ def __init__(
                 pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
             )

-        self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
-        self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
+        self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

     def forward(
         self,
@@ -104,6 +105,7 @@ def forward(
         temb: torch.Tensor,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        breakpoint()
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

         if not self.context_pre_only:
@@ -140,6 +142,40 @@ def forward(
         return hidden_states, encoder_hidden_states


+class MochiRoPE(nn.Module):
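+    # Rotary position embedding over the (frame, height, width) patch grid. Positions are
+    # rescaled so the spatial grid always spans the same area as the 192x192 base grid,
+    # keeping the learned frequencies usable across resolutions.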
+    def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
+        super().__init__()
+
+        self.target_area = base_height * base_width
+
+    def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
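+        # midpoints of `num` equal-width bins spanning [start, stop], used as per-patch coordinates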
+        edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
+        return (edges[:-1] + edges[1:]) / 2
+
+    def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
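+        # scale the spatial coordinates so the height x width patch grid covers the base target area,
+        # then build one (t, h, w) coordinate triple for every patch in the video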
+        scale = (self.target_area / (height * width)) ** 0.5
+
+        t = torch.arange(num_frames, device=device, dtype=dtype)
+        h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
+        w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
+
+        grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
+
+        positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
+        return positions
+
+    def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
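+        # mix the (t, h, w) positions with the learned per-axis frequencies of shape
+        # (3, num_heads, head_dim // 2), giving one angle per position, head and frequency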
+        freqs = torch.einsum("nd,dhf->nhf", pos, freqs)
+        freqs_cos = torch.cos(freqs)
+        freqs_sin = torch.sin(freqs)
+        return freqs_cos, freqs_sin
+
+    def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
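+        # returns cos/sin tensors of shape (num_frames * height * width, num_heads, head_dim // 2)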
+        pos = self._get_positions(num_frames, height, width, device, dtype)
+        rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
+        return rope_cos, rope_sin
+
+
 @maybe_allow_in_graph
 class MochiTransformer3DModel(ModelMixin, ConfigMixin):
     _supports_gradient_checkpointing = True
@@ -169,6 +205,7 @@ def __init__(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
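+            # disable PatchEmbed's built-in positional embedding; rotary embeddings from MochiRoPE are used instead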
+            pos_embed_type=None,
         )

         self.time_embed = MochiCombinedTimestepCaptionEmbedding(
@@ -180,6 +217,7 @@ def __init__(
         )

         self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
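+        # MochiRoPE itself is parameter-free; it turns the learned pos_frequencies above into cos/sin tables each forward pass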
+        self.rope = MochiRoPE()

         self.transformer_blocks = nn.ModuleList(
             [
@@ -207,7 +245,6 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -224,6 +261,8 @@ def forward(
         hidden_states = self.patch_embed(hidden_states)
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

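+        # build the rotary cos/sin pair from the learned frequencies and the current patch-grid size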
+        image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32)
+
         for i, block in enumerate(self.transformer_blocks):
             hidden_states, encoder_hidden_states = block(
                 hidden_states=hidden_states,
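
A minimal usage sketch of the new MochiRoPE module (illustrative only: the import path assumes this branch is installed, and the head count, head dim and patch-grid size below are made-up values rather than the real Mochi configuration):

import torch

from diffusers.models.transformers.transformer_mochi import MochiRoPE  # path assumes this PR's branch is installed

num_heads, head_dim = 24, 128                                       # illustrative sizes
pos_frequencies = torch.randn(3, num_heads, head_dim // 2) * 0.02   # stand-in for the learned nn.Parameter

rope = MochiRoPE()  # target area defaults to 192 * 192
cos, sin = rope(pos_frequencies, num_frames=3, height=30, width=53, dtype=torch.float32)

# one angle per patch position, attention head and frequency
print(cos.shape)  # torch.Size([4770, 24, 64]) == (3 * 30 * 53, num_heads, head_dim // 2)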