Skip to content

Commit 06a54ec

Browse files
committed
run formatter
1 parent caa4cae commit 06a54ec

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

mlx_audio/stt/models/mms/mms.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
import mlx.nn as nn
88

99
from ..base import STTOutput
10-
from ..wav2vec.wav2vec import (
11-
ModelConfig,
12-
Wav2Vec2Model,
13-
)
10+
from ..wav2vec.wav2vec import ModelConfig, Wav2Vec2Model
1411

1512

1613
class Model(nn.Module):
@@ -74,6 +71,7 @@ def generate(
7471

7572
if isinstance(audio, (str, Path)):
7673
from mlx_audio.stt.utils import load_audio
74+
7775
audio = load_audio(str(audio), sr=self.sample_rate, dtype=dtype)
7876
elif not isinstance(audio, mx.array):
7977
audio = mx.array(audio)
@@ -84,7 +82,9 @@ def generate(
8482
if audio.dtype != dtype:
8583
audio = audio.astype(dtype)
8684

87-
audio = (audio - mx.mean(audio, axis=-1, keepdims=True)) / (mx.std(audio, axis=-1, keepdims=True) + 1e-7)
85+
audio = (audio - mx.mean(audio, axis=-1, keepdims=True)) / (
86+
mx.std(audio, axis=-1, keepdims=True) + 1e-7
87+
)
8888

8989
logits = self(audio)
9090
mx.eval(logits)
@@ -147,13 +147,16 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
147147
with open(vocab_path) as f:
148148
vocab = json.load(f)
149149
if isinstance(next(iter(vocab.values())), dict):
150-
lang_vocab = vocab.get("eng", vocab.get("en", next(iter(vocab.values()))))
150+
lang_vocab = vocab.get(
151+
"eng", vocab.get("en", next(iter(vocab.values())))
152+
)
151153
model._vocab = {v: k for k, v in lang_vocab.items()}
152154
else:
153155
model._vocab = {v: k for k, v in vocab.items()}
154156

155157
try:
156158
from transformers import AutoProcessor
159+
157160
model._processor = AutoProcessor.from_pretrained(str(model_path))
158161
except Exception:
159162
model._processor = None

mlx_audio/stt/models/mms/tests/test_mms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,16 @@ def test_keeps_lm_head(self):
8585
self.assertIn("lm_head.weight", sanitized)
8686

8787
def test_keeps_wav2vec2_prefix(self):
88-
weights = {"wav2vec2.encoder.layers.0.attention.q_proj.weight": mx.zeros((32, 32))}
88+
weights = {
89+
"wav2vec2.encoder.layers.0.attention.q_proj.weight": mx.zeros((32, 32))
90+
}
8991
sanitized = self.model.sanitize(weights)
9092
self.assertIn("wav2vec2.encoder.layers.0.attention.q_proj.weight", sanitized)
9193

9294
def test_conv_transpose(self):
93-
weights = {"wav2vec2.feature_extractor.conv_layers.0.conv.weight": mx.zeros((16, 1, 4))}
95+
weights = {
96+
"wav2vec2.feature_extractor.conv_layers.0.conv.weight": mx.zeros((16, 1, 4))
97+
}
9498
sanitized = self.model.sanitize(weights)
9599
key = "wav2vec2.feature_extractor.conv_layers.0.conv.weight"
96100
self.assertEqual(sanitized[key].shape, (16, 4, 1))

0 commit comments

Comments
 (0)