@@ -113,9 +113,7 @@ def __init__(self, model_dim: int, rank: int, eps: float):
113113 self .scale_up = nn .Linear (rank , model_dim , bias = True )
114114 self .gate_up = nn .Linear (rank , model_dim , bias = True )
115115
116- def __call__ (
117- self , x : mx .array , cond_embed : mx .array
118- ) -> Tuple [mx .array , mx .array ]:
116+ def __call__ (self , x : mx .array , cond_embed : mx .array ) -> Tuple [mx .array , mx .array ]:
119117 shift , scale , gate = mx .split (cond_embed , 3 , axis = - 1 )
120118 shift = self .shift_up (self .shift_down (nn .silu (shift ))) + shift
121119 scale = self .scale_up (self .scale_down (nn .silu (scale ))) + scale
@@ -324,7 +322,9 @@ def __init__(self, dim: int, heads: int, mlp_hidden_dim: int, norm_eps: float):
324322 def __call__ (
325323 self , x : mx .array , mask : Optional [mx .array ], freqs_cis : RotaryCache
326324 ) -> mx .array :
327- x = x + self .attention (self .attention_norm (x ), key_mask = mask , freqs_cis = freqs_cis )
325+ x = x + self .attention (
326+ self .attention_norm (x ), key_mask = mask , freqs_cis = freqs_cis
327+ )
328328 x = x + self .mlp (self .mlp_norm (x ))
329329 return x
330330
@@ -394,9 +394,7 @@ def __init__(
394394 TextBlock (dim , heads , mlp_hidden , norm_eps ) for _ in range (num_layers )
395395 ]
396396
397- def __call__ (
398- self , latent : mx .array , mask : Optional [mx .array ] = None
399- ) -> mx .array :
397+ def __call__ (self , latent : mx .array , mask : Optional [mx .array ] = None ) -> mx .array :
400398 x = self .in_proj (latent ) / 6.0
401399 freqs_cis = precompute_freqs_cis (self .head_dim , x .shape [1 ])
402400 if mask is not None :
@@ -453,8 +451,13 @@ def __call__(
453451 ) -> mx .array :
454452 x_norm , attn_gate = self .attention_adaln (x , cond_embed )
455453 x = x + attn_gate * self .attention (
456- x_norm , text_mask , speaker_mask , freqs_cis ,
457- kv_cache_text , kv_cache_speaker , start_pos ,
454+ x_norm ,
455+ text_mask ,
456+ speaker_mask ,
457+ freqs_cis ,
458+ kv_cache_text ,
459+ kv_cache_speaker ,
460+ start_pos ,
458461 )
459462 x_norm , mlp_gate = self .mlp_adaln (x , cond_embed )
460463 x = x + mlp_gate * self .mlp (x_norm )
@@ -554,7 +557,9 @@ def build_kv_cache(
554557 speaker_state : mx .array ,
555558 ) -> Tuple [List [KVCache ], List [KVCache ]]:
556559 """Pre-compute per-layer text/speaker KV projections for fast sampling."""
557- kv_text = [block .attention .get_kv_cache_text (text_state ) for block in self .blocks ]
560+ kv_text = [
561+ block .attention .get_kv_cache_text (text_state ) for block in self .blocks
562+ ]
558563 kv_speaker = [
559564 block .attention .get_kv_cache_speaker (speaker_state ) for block in self .blocks
560565 ]
@@ -576,7 +581,9 @@ def forward_with_conditions(
576581 kv_speaker : Optional [List [KVCache ]] = None ,
577582 start_pos : int = 0 ,
578583 ) -> mx .array :
579- t_embed = get_timestep_embedding (t , self .cfg .timestep_embed_dim ).astype (x_t .dtype )
584+ t_embed = get_timestep_embedding (t , self .cfg .timestep_embed_dim ).astype (
585+ x_t .dtype
586+ )
580587 cond_embed = self .cond_module (t_embed )[:, None , :] # (B, 1, 3*model_dim)
581588
582589 x = self .in_proj (x_t )
@@ -594,8 +601,14 @@ def forward_with_conditions(
594601 else block .attention .get_kv_cache_speaker (speaker_state )
595602 )
596603 x = block (
597- x , cond_embed , text_mask , speaker_mask ,
598- freqs_cis , kv_t , kv_s , start_pos ,
604+ x ,
605+ cond_embed ,
606+ text_mask ,
607+ speaker_mask ,
608+ freqs_cis ,
609+ kv_t ,
610+ kv_s ,
611+ start_pos ,
599612 )
600613
601614 x = self .out_norm (x )
0 commit comments