Skip to content

Commit 3c874c6

Browse files
authored
Merge pull request #553 from mm65x/add-mms-asr
add mms asr model
2 parents 3b659b1 + c4d7b07 commit 3c874c6

File tree

9 files changed

+355
-12
lines changed

9 files changed

+355
-12
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ for result in model.generate("Hello from MLX-Audio!", voice="af_heart"):
102102
| **VibeVoice-ASR** | Microsoft's 9B ASR with diarization & timestamps | Multiple | [mlx-community/VibeVoice-ASR-bf16](https://huggingface.co/mlx-community/VibeVoice-ASR-bf16) |
103103
| **Canary** | NVIDIA's multilingual ASR with translation | 25 EU + RU, UK | [README](mlx_audio/stt/models/canary/README.md) |
104104
| **Moonshine** | Useful Sensors' lightweight ASR | EN | [README](mlx_audio/stt/models/moonshine/README.md) |
105+
| **MMS** | Meta's massively multilingual ASR with adapters | 1000+ | [README](mlx_audio/stt/models/mms/README.md) |
105106

106107

107108
### Voice Activity Detection / Speaker Diarization (VAD)

mlx_audio/stt/models/mms/README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# MMS ASR
2+
3+
MLX implementation of Meta's Massively Multilingual Speech (MMS) ASR model, supporting 1000+ languages through language-specific adapter layers on top of a shared wav2vec2 backbone.
4+
5+
## Available Models
6+
7+
| Model | Parameters | Languages | Description |
8+
|-------|------------|-----------|-------------|
9+
| [facebook/mms-1b-fl102](https://huggingface.co/facebook/mms-1b-fl102) | 1B | 102 | Finetuned on FLEURS |
10+
| [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all) | 1B | 1162 | All supported languages |
11+
12+
## Python Usage
13+
14+
```python
15+
from mlx_audio.stt import load
16+
17+
model = load("facebook/mms-1b-fl102")
18+
19+
result = model.generate("audio.wav")
20+
print(result.text)
21+
```
22+
23+
## Architecture
24+
25+
- Wav2Vec2 encoder with convolutional feature extractor
26+
- Per-layer attention adapter modules for language adaptation
27+
- CTC head with language-specific vocabulary
28+
- Adapter weights loaded automatically from `adapter.{lang}.safetensors`
29+
- 16kHz audio input with zero-mean unit-variance normalization
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .mms import Model, ModelConfig

mlx_audio/stt/models/mms/mms.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import json
2+
import time
3+
from pathlib import Path
4+
from typing import Dict, List, Optional, Tuple
5+
6+
import mlx.core as mx
7+
import mlx.nn as nn
8+
9+
from ..base import STTOutput
10+
from ..wav2vec.wav2vec import ModelConfig, Wav2Vec2Model
11+
12+
13+
class Model(nn.Module):
    """MMS (Massively Multilingual Speech) CTC ASR model.

    A wav2vec2 encoder followed by a linear CTC head. Language-specific
    adapter weights and the matching vocabulary are loaded in
    ``post_load_hook`` from files shipped next to the checkpoint
    (``adapter.{lang}.safetensors`` / ``vocab.json``).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        if isinstance(config, dict):
            config = ModelConfig.from_dict(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # id -> token mapping built from vocab.json in post_load_hook
        self._vocab: Optional[Dict[int, str]] = None
        # optional HF processor; preferred for decoding when available
        self._processor = None

    @property
    def sample_rate(self) -> int:
        """Expected input sample rate in Hz (MMS models are 16 kHz)."""
        return 16000

    def __call__(self, input_values: mx.array) -> mx.array:
        """Return per-frame CTC logits of shape (batch, frames, vocab_size)."""
        outputs = self.wav2vec2(input_values)
        logits = self.lm_head(outputs.last_hidden_state)
        return logits

    def _ctc_decode(self, logits: mx.array) -> List[List[int]]:
        """Greedy CTC decode: collapse repeated frames, then drop blanks (id 0)."""
        predictions = mx.argmax(logits, axis=-1)
        batch_tokens: List[List[int]] = []
        for b in range(predictions.shape[0]):
            tokens: List[int] = []
            prev = -1
            for t in range(predictions.shape[1]):
                token = int(predictions[b, t])
                if token != prev and token != 0:
                    tokens.append(token)
                # Track the raw previous frame (including blanks) so a
                # repeated symbol separated by a blank is emitted twice.
                prev = token
            batch_tokens.append(tokens)
        return batch_tokens

    def _tokens_to_text(self, tokens: List[int]) -> str:
        """Convert CTC token ids to text via processor, vocab, or raw-id fallback."""
        if self._processor is not None:
            return self._processor.decode(tokens)
        if self._vocab is None:
            # No decoding resources available: emit raw ids so the output
            # is still inspectable rather than silently empty.
            return " ".join(str(t) for t in tokens)
        # "|" is the word delimiter in MMS/wav2vec2 vocabularies.
        return "".join(self._vocab.get(t, "") for t in tokens).replace("|", " ")

    def generate(
        self,
        audio,
        *,
        verbose: bool = False,
        dtype: mx.Dtype = mx.float32,
        **kwargs,
    ) -> STTOutput:
        """Transcribe ``audio`` (path, array-like, or ``mx.array``) to text.

        Args:
            audio: Path to an audio file, a 1-D waveform, or a batched
                (1, samples) array sampled at 16 kHz.
            verbose: Print the transcription when True.
            dtype: Compute dtype for the waveform.
            **kwargs: Options from the shared STT ``generate`` interface
                that MMS does not support; accepted and ignored for
                cross-model compatibility.

        Returns:
            STTOutput with the transcription, a single whole-clip segment,
            and wall-clock timing.
        """
        # Accept-and-ignore options supported by other STT models.
        for unused in (
            "generation_stream",
            "max_tokens",
            "temperature",
            "language",
            "source_lang",
            "target_lang",
            "stream",
        ):
            kwargs.pop(unused, None)

        start_time = time.time()

        if isinstance(audio, (str, Path)):
            from mlx_audio.stt.utils import load_audio

            audio = load_audio(str(audio), sr=self.sample_rate, dtype=dtype)
        elif not isinstance(audio, mx.array):
            audio = mx.array(audio)

        if audio.ndim == 1:
            audio = audio[None, :]

        if audio.dtype != dtype:
            audio = audio.astype(dtype)

        # Zero-mean / unit-variance normalization (as in the HF
        # Wav2Vec2 feature extractor); epsilon guards silent input.
        audio = (audio - mx.mean(audio, axis=-1, keepdims=True)) / (
            mx.std(audio, axis=-1, keepdims=True) + 1e-7
        )

        logits = self(audio)
        mx.eval(logits)

        decoded = self._ctc_decode(logits)
        text = self._tokens_to_text(decoded[0])

        total_time = time.time() - start_time

        if verbose:
            print(f"Text: {text}")

        # Report the clip duration as the segment end instead of a dummy 0.0.
        duration = audio.shape[-1] / self.sample_rate
        return STTOutput(
            text=text.strip(),
            segments=[{"text": text.strip(), "start": 0.0, "end": duration}],
            total_time=total_time,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Adapt PyTorch checkpoint weights to MLX layout.

        - Transposes conv kernels (PyTorch OIW -> MLX OWI).
        - Renames weight-norm parametrization keys to weight_g / weight_v.
        - Drops pretraining-only tensors (quantizer, projections,
          masked_spec_embed) that the inference model does not define.
        """
        sanitized: Dict[str, mx.array] = {}
        for k, v in weights.items():
            if k.endswith(".conv.weight"):
                v = v.swapaxes(1, 2)
            if k.endswith(".conv.weight_v") or k.endswith(".conv.weight_g"):
                v = v.swapaxes(1, 2)
            if k.endswith(".parametrizations.weight.original0"):
                k = k.replace(".parametrizations.weight.original0", ".weight_g")
                v = v.swapaxes(1, 2)
            if k.endswith(".parametrizations.weight.original1"):
                k = k.replace(".parametrizations.weight.original1", ".weight_v")
                v = v.swapaxes(1, 2)
            if (
                k.startswith("quantizer.")
                or k.startswith("project_")
                or k == "masked_spec_embed"
            ):
                continue

            sanitized[k] = v
        return sanitized

    @classmethod
    def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
        """Load language adapter weights, vocabulary, and optional processor.

        Prefers the English adapter (``adapter.eng.safetensors``); falls
        back to the first adapter found in the checkpoint directory.
        """
        model_path = Path(model_path)

        adapter_path = model_path / "adapter.eng.safetensors"
        if not adapter_path.exists():
            adapters = list(model_path.glob("adapter.*.safetensors"))
            if adapters:
                adapter_path = adapters[0]

        if adapter_path.exists():
            adapter_weights = mx.load(str(adapter_path))
            sanitized = model.sanitize(adapter_weights)
            # strict=False: adapters only cover a subset of model weights.
            model.load_weights(list(sanitized.items()), strict=False)

        vocab_path = model_path / "vocab.json"
        if vocab_path.exists():
            with open(vocab_path) as f:
                vocab = json.load(f)
            # Guard against an empty vocab.json: next(iter(...)) on an
            # empty dict would raise StopIteration.
            if vocab and isinstance(next(iter(vocab.values())), dict):
                # Multilingual layout: {lang: {token: id}}; prefer English.
                lang_vocab = vocab.get(
                    "eng", vocab.get("en", next(iter(vocab.values())))
                )
                model._vocab = {v: k for k, v in lang_vocab.items()}
            else:
                model._vocab = {v: k for k, v in vocab.items()}

        # Best-effort: a processor gives higher-quality decoding, but
        # transformers may be absent; fall back to the raw vocab map.
        try:
            from transformers import AutoProcessor

            model._processor = AutoProcessor.from_pretrained(str(model_path))
        except Exception:
            model._processor = None
        return model

mlx_audio/stt/models/mms/tests/__init__.py

Whitespace-only changes.
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import unittest
2+
3+
import mlx.core as mx
4+
5+
from mlx_audio.stt.models.mms.mms import Model
6+
from mlx_audio.stt.models.wav2vec.wav2vec import ModelConfig
7+
8+
9+
def _small_config():
    """Build a tiny wav2vec2 config so tests run fast."""
    params = {
        "vocab_size": 32,
        "hidden_size": 32,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "intermediate_size": 64,
        "conv_dim": (16, 16),
        "conv_stride": (2, 2),
        "conv_kernel": (4, 3),
        "num_feat_extract_layers": 2,
        "num_conv_pos_embeddings": 8,
        "num_conv_pos_embedding_groups": 4,
    }
    return ModelConfig(**params)
23+
24+
25+
class TestConfig(unittest.TestCase):
    """ModelConfig construction and defaults."""

    def test_defaults(self):
        cfg = ModelConfig()
        self.assertEqual(cfg.model_type, "wav2vec2")
        self.assertEqual(cfg.hidden_size, 768)

    def test_from_dict(self):
        cfg = ModelConfig.from_dict({"hidden_size": 1024, "num_hidden_layers": 24})
        self.assertEqual(cfg.hidden_size, 1024)
36+
37+
38+
class TestCTCDecode(unittest.TestCase):
    """Greedy CTC decoding: repeats collapse, blanks (id 0) are removed."""

    def test_greedy_decode(self):
        model = Model(_small_config())
        logits = mx.zeros((1, 10, 32))
        # Frame-wise argmax sequence: 5 5 8 <blank> 8 -> decoded [5, 8, 8]
        for frame, token in ((0, 5), (1, 5), (2, 8), (3, 0), (4, 8)):
            logits = logits.at[0, frame, token].add(10.0)
        decoded = model._ctc_decode(logits)
        self.assertEqual(decoded[0], [5, 8, 8])

    def test_all_blanks(self):
        model = Model(_small_config())
        decoded = model._ctc_decode(mx.zeros((1, 5, 32)))
        self.assertEqual(decoded[0], [])
58+
59+
60+
class TestTokensToText(unittest.TestCase):
    """Token-id to string conversion with and without a vocabulary."""

    def test_with_vocab(self):
        model = Model(_small_config())
        model._vocab = {1: "h", 2: "e", 3: "l", 4: "o", 5: "|"}
        # "|" maps to a space; unknown ids would map to "".
        self.assertEqual(
            model._tokens_to_text([1, 2, 3, 3, 4, 5, 1, 2]), "hello he"
        )

    def test_without_vocab(self):
        model = Model(_small_config())
        # No vocab loaded: raw ids are joined with spaces as a fallback.
        self.assertEqual(model._tokens_to_text([1, 2, 3]), "1 2 3")
74+
75+
76+
class TestModelSanitize(unittest.TestCase):
    """Weight sanitization: key filtering and conv-kernel transposition."""

    def setUp(self):
        self.config = _small_config()
        self.model = Model(self.config)

    def test_keeps_lm_head(self):
        out = self.model.sanitize({"lm_head.weight": mx.zeros((32, 32))})
        self.assertIn("lm_head.weight", out)

    def test_keeps_wav2vec2_prefix(self):
        key = "wav2vec2.encoder.layers.0.attention.q_proj.weight"
        out = self.model.sanitize({key: mx.zeros((32, 32))})
        self.assertIn(key, out)

    def test_conv_transpose(self):
        key = "wav2vec2.feature_extractor.conv_layers.0.conv.weight"
        out = self.model.sanitize({key: mx.zeros((16, 1, 4))})
        # PyTorch (out, in, kernel) becomes MLX (out, kernel, in).
        self.assertEqual(out[key].shape, (16, 4, 1))

    def test_skips_quantizer(self):
        out = self.model.sanitize({"quantizer.weight_proj.weight": mx.zeros((32, 32))})
        self.assertEqual(len(out), 0)

    def test_skips_masked_spec(self):
        out = self.model.sanitize({"masked_spec_embed": mx.zeros((32,))})
        self.assertEqual(len(out), 0)
111+
112+
113+
class TestModel(unittest.TestCase):
    """Smoke tests for model construction and the forward pass."""

    def setUp(self):
        self.config = _small_config()
        self.model = Model(self.config)

    def test_init(self):
        self.assertIsNotNone(self.model.wav2vec2)
        self.assertIsNotNone(self.model.lm_head)

    def test_sample_rate(self):
        self.assertEqual(self.model.sample_rate, 16000)

    def test_forward(self):
        waveform = mx.random.normal((1, 320))
        logits = self.model(waveform)
        mx.eval(logits)
        # Batch dimension preserved; last axis is the vocab size.
        self.assertEqual(logits.shape[0], 1)
        self.assertEqual(logits.shape[2], 32)
132+
133+
134+
# Allow running this test module directly (e.g. `python test_mms.py`).
if __name__ == "__main__":
    unittest.main()

mlx_audio/stt/models/moonshine/moonshine.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import time
2-
import warnings
32
from pathlib import Path
43
from typing import Any, Dict, List, Optional, Tuple, Union
54

@@ -462,14 +461,3 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
462461
except Exception:
463462
pass
464463
return model
465-
466-
@classmethod
467-
def from_pretrained(cls, path_or_repo: str, *, dtype: mx.Dtype = mx.float32):
468-
warnings.warn(
469-
"Model.from_pretrained() is deprecated. Use mlx_audio.stt.load() instead.",
470-
DeprecationWarning,
471-
stacklevel=2,
472-
)
473-
from mlx_audio.stt.utils import load
474-
475-
return load(path_or_repo)

0 commit comments

Comments
 (0)