
Commit 4df6101

run formatter
1 parent: 5e895fc

File tree: 2 files changed (+32, -11 lines)

mlx_audio/stt/models/moonshine/moonshine.py

Lines changed: 28 additions & 9 deletions
@@ -11,15 +11,21 @@


 class MoonshineRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, max_position_embeddings: int = 512, base: float = 10000.0):
+    def __init__(
+        self, dim: int, max_position_embeddings: int = 512, base: float = 10000.0
+    ):
         super().__init__()
         inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
         self._inv_freq = inv_freq  # shape: (dim // 2,)
         self._dim = dim
         self._max_seq_len = max_position_embeddings

-    def __call__(self, x: mx.array, position_ids: mx.array) -> Tuple[mx.array, mx.array]:
-        freqs = position_ids[:, :, None].astype(mx.float32) * self._inv_freq[None, None, :]
+    def __call__(
+        self, x: mx.array, position_ids: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        freqs = (
+            position_ids[:, :, None].astype(mx.float32) * self._inv_freq[None, None, :]
+        )
         emb = mx.concatenate([freqs, freqs], axis=-1)
         cos = mx.cos(emb)
         sin = mx.sin(emb)
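
As an aside, the cos/sin tables built by the reformatted __call__ above can be reproduced standalone. The sketch below is illustrative only; dim=8 and seq_len=4 are example values, not taken from the model config in this commit.

    # Illustrative sketch of the rotary-embedding tables computed in
    # MoonshineRotaryEmbedding.__call__ (dim and seq_len are example values).
    import mlx.core as mx

    dim, base, seq_len = 8, 10000.0, 4
    inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))  # (dim // 2,)
    position_ids = mx.arange(seq_len)[None, :]  # (batch=1, seq_len)

    freqs = position_ids[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
    emb = mx.concatenate([freqs, freqs], axis=-1)  # (1, seq_len, dim)
    cos, sin = mx.cos(emb), mx.sin(emb)
    print(cos.shape, sin.shape)  # (1, 4, 8) (1, 4, 8)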
@@ -185,7 +191,9 @@ def __init__(self, config: ModelConfig):
         self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False)

-    def __call__(self, x: mx.array, position_ids: Optional[mx.array] = None) -> mx.array:
+    def __call__(
+        self, x: mx.array, position_ids: Optional[mx.array] = None
+    ) -> mx.array:
         residual = x
         x = self.input_layernorm(x)
         x, _ = self.self_attn(x, position_ids=position_ids)
@@ -262,7 +270,10 @@ def __init__(self, config: ModelConfig):
         self.groupnorm = nn.GroupNorm(1, dim)
         self.conv2 = nn.Conv1d(dim, 2 * dim, kernel_size=7, stride=3, bias=True)
         self.conv3 = nn.Conv1d(2 * dim, dim, kernel_size=3, stride=2, bias=True)
-        self.layers = [MoonshineEncoderLayer(config) for _ in range(config.encoder_num_hidden_layers)]
+        self.layers = [
+            MoonshineEncoderLayer(config)
+            for _ in range(config.encoder_num_hidden_layers)
+        ]
         self.layer_norm = nn.LayerNorm(dim, bias=False)

     def __call__(self, audio: mx.array) -> mx.array:
@@ -285,7 +296,10 @@ class MoonshineDecoder(nn.Module):
     def __init__(self, config: ModelConfig):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = [MoonshineDecoderLayer(config) for _ in range(config.decoder_num_hidden_layers)]
+        self.layers = [
+            MoonshineDecoderLayer(config)
+            for _ in range(config.decoder_num_hidden_layers)
+        ]
         self.norm = nn.LayerNorm(config.hidden_size, bias=False)

     def __call__(
@@ -297,7 +311,9 @@ def __call__(
         x = self.embed_tokens(tokens)

         if cache is None:
-            cache = [{"self_attn": None, "cross_attn": None} for _ in range(len(self.layers))]
+            cache = [
+                {"self_attn": None, "cross_attn": None} for _ in range(len(self.layers))
+            ]

         new_cache = []
         for i, layer in enumerate(self.layers):
@@ -353,6 +369,7 @@ def generate(

         if isinstance(audio, (str, Path)):
             from mlx_audio.stt.utils import load_audio
+
             audio = load_audio(str(audio), sr=self.sample_rate, dtype=dtype)
         elif not isinstance(audio, mx.array):
             audio = mx.array(audio)
@@ -415,10 +432,10 @@ def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
             new_key = key

             if key.startswith("model.encoder."):
-                new_key = key[len("model."):]
+                new_key = key[len("model.") :]

             elif key.startswith("model.decoder."):
-                new_key = key[len("model."):]
+                new_key = key[len("model.") :]

             elif key.startswith("proj_out."):
                 if self.config.tie_word_embeddings:
@@ -440,6 +457,7 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
         model_path = Path(model_path)
         try:
             from transformers import AutoTokenizer
+
             model._tokenizer = AutoTokenizer.from_pretrained(str(model_path))
         except Exception:
             pass
@@ -453,4 +471,5 @@ def from_pretrained(cls, path_or_repo: str, *, dtype: mx.Dtype = mx.float32):
             stacklevel=2,
         )
         from mlx_audio.stt.utils import load
+
         return load(path_or_repo)
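
For context, from_pretrained in this file emits a warning and then forwards to the shared loader, so a minimal usage sketch looks like the following. The model path and audio file name are placeholders, and it assumes generate's remaining parameters have usable defaults.

    # Hypothetical usage; "path/to/moonshine" and "clip.wav" are placeholders.
    from mlx_audio.stt.utils import load

    model = load("path/to/moonshine")    # the call from_pretrained forwards to
    result = model.generate("clip.wav")  # str/Path input goes through load_audio, per the hunk above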

mlx_audio/stt/models/moonshine/tests/test_moonshine.py

Lines changed: 4 additions & 2 deletions
@@ -4,14 +4,14 @@

 from mlx_audio.stt.models.moonshine.config import ModelConfig
 from mlx_audio.stt.models.moonshine.moonshine import (
+    Model,
     MoonshineAttention,
     MoonshineDecoder,
     MoonshineDecoderLayer,
     MoonshineDecoderMLP,
     MoonshineEncoder,
     MoonshineEncoderLayer,
     MoonshineEncoderMLP,
-    Model,
 )

@@ -198,7 +198,9 @@ def test_decoder_key_mapping(self):
         self.assertIn("decoder.layers.0.self_attn.q_proj.weight", sanitized)

     def test_cross_attn_mapping(self):
-        weights = {"model.decoder.layers.0.encoder_attn.q_proj.weight": mx.zeros((32, 32))}
+        weights = {
+            "model.decoder.layers.0.encoder_attn.q_proj.weight": mx.zeros((32, 32))
+        }
         sanitized = self.model.sanitize(weights)
         self.assertIn("decoder.layers.0.encoder_attn.q_proj.weight", sanitized)
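
The assertion above pairs with the slice-based prefix stripping reformatted in sanitize(); a pure-Python sketch of that mapping, using the same key as test_cross_attn_mapping:

    # The test key run through the same prefix-stripping logic as sanitize().
    key = "model.decoder.layers.0.encoder_attn.q_proj.weight"
    new_key = key[len("model.") :] if key.startswith("model.decoder.") else key
    print(new_key)  # decoder.layers.0.encoder_attn.q_proj.weight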
