1111
1212
1313class MoonshineRotaryEmbedding (nn .Module ):
14- def __init__ (self , dim : int , max_position_embeddings : int = 512 , base : float = 10000.0 ):
14+ def __init__ (
15+ self , dim : int , max_position_embeddings : int = 512 , base : float = 10000.0
16+ ):
1517 super ().__init__ ()
1618 inv_freq = 1.0 / (base ** (mx .arange (0 , dim , 2 , dtype = mx .float32 ) / dim ))
1719 self ._inv_freq = inv_freq # shape: (dim // 2,)
1820 self ._dim = dim
1921 self ._max_seq_len = max_position_embeddings
2022
21- def __call__ (self , x : mx .array , position_ids : mx .array ) -> Tuple [mx .array , mx .array ]:
22- freqs = position_ids [:, :, None ].astype (mx .float32 ) * self ._inv_freq [None , None , :]
23+ def __call__ (
24+ self , x : mx .array , position_ids : mx .array
25+ ) -> Tuple [mx .array , mx .array ]:
26+ freqs = (
27+ position_ids [:, :, None ].astype (mx .float32 ) * self ._inv_freq [None , None , :]
28+ )
2329 emb = mx .concatenate ([freqs , freqs ], axis = - 1 )
2430 cos = mx .cos (emb )
2531 sin = mx .sin (emb )
@@ -185,7 +191,9 @@ def __init__(self, config: ModelConfig):
185191 self .input_layernorm = nn .LayerNorm (config .hidden_size , bias = False )
186192 self .post_attention_layernorm = nn .LayerNorm (config .hidden_size , bias = False )
187193
188- def __call__ (self , x : mx .array , position_ids : Optional [mx .array ] = None ) -> mx .array :
194+ def __call__ (
195+ self , x : mx .array , position_ids : Optional [mx .array ] = None
196+ ) -> mx .array :
189197 residual = x
190198 x = self .input_layernorm (x )
191199 x , _ = self .self_attn (x , position_ids = position_ids )
@@ -262,7 +270,10 @@ def __init__(self, config: ModelConfig):
262270 self .groupnorm = nn .GroupNorm (1 , dim )
263271 self .conv2 = nn .Conv1d (dim , 2 * dim , kernel_size = 7 , stride = 3 , bias = True )
264272 self .conv3 = nn .Conv1d (2 * dim , dim , kernel_size = 3 , stride = 2 , bias = True )
265- self .layers = [MoonshineEncoderLayer (config ) for _ in range (config .encoder_num_hidden_layers )]
273+ self .layers = [
274+ MoonshineEncoderLayer (config )
275+ for _ in range (config .encoder_num_hidden_layers )
276+ ]
266277 self .layer_norm = nn .LayerNorm (dim , bias = False )
267278
268279 def __call__ (self , audio : mx .array ) -> mx .array :
@@ -285,7 +296,10 @@ class MoonshineDecoder(nn.Module):
285296 def __init__ (self , config : ModelConfig ):
286297 super ().__init__ ()
287298 self .embed_tokens = nn .Embedding (config .vocab_size , config .hidden_size )
288- self .layers = [MoonshineDecoderLayer (config ) for _ in range (config .decoder_num_hidden_layers )]
299+ self .layers = [
300+ MoonshineDecoderLayer (config )
301+ for _ in range (config .decoder_num_hidden_layers )
302+ ]
289303 self .norm = nn .LayerNorm (config .hidden_size , bias = False )
290304
291305 def __call__ (
@@ -297,7 +311,9 @@ def __call__(
297311 x = self .embed_tokens (tokens )
298312
299313 if cache is None :
300- cache = [{"self_attn" : None , "cross_attn" : None } for _ in range (len (self .layers ))]
314+ cache = [
315+ {"self_attn" : None , "cross_attn" : None } for _ in range (len (self .layers ))
316+ ]
301317
302318 new_cache = []
303319 for i , layer in enumerate (self .layers ):
@@ -353,6 +369,7 @@ def generate(
353369
354370 if isinstance (audio , (str , Path )):
355371 from mlx_audio .stt .utils import load_audio
372+
356373 audio = load_audio (str (audio ), sr = self .sample_rate , dtype = dtype )
357374 elif not isinstance (audio , mx .array ):
358375 audio = mx .array (audio )
@@ -415,10 +432,10 @@ def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
415432 new_key = key
416433
417434 if key .startswith ("model.encoder." ):
418- new_key = key [len ("model." ):]
435+ new_key = key [len ("model." ) :]
419436
420437 elif key .startswith ("model.decoder." ):
421- new_key = key [len ("model." ):]
438+ new_key = key [len ("model." ) :]
422439
423440 elif key .startswith ("proj_out." ):
424441 if self .config .tie_word_embeddings :
@@ -440,6 +457,7 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
440457 model_path = Path (model_path )
441458 try :
442459 from transformers import AutoTokenizer
460+
443461 model ._tokenizer = AutoTokenizer .from_pretrained (str (model_path ))
444462 except Exception :
445463 pass
@@ -453,4 +471,5 @@ def from_pretrained(cls, path_or_repo: str, *, dtype: mx.Dtype = mx.float32):
453471 stacklevel = 2 ,
454472 )
455473 from mlx_audio .stt .utils import load
474+
456475 return load (path_or_repo )
0 commit comments