Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ for result in model.generate("Hello from MLX-Audio!", voice="af_heart"):
| **VibeVoice-ASR** | Microsoft's 9B ASR with diarization & timestamps | Multiple | [mlx-community/VibeVoice-ASR-bf16](https://huggingface.co/mlx-community/VibeVoice-ASR-bf16) |
| **Canary** | NVIDIA's multilingual ASR with translation | 25 EU + RU, UK | [README](mlx_audio/stt/models/canary/README.md) |
| **Moonshine** | Useful Sensors' lightweight ASR | EN | [README](mlx_audio/stt/models/moonshine/README.md) |
| **MMS** | Meta's massively multilingual ASR with adapters | 1000+ | [README](mlx_audio/stt/models/mms/README.md) |


### Voice Activity Detection / Speaker Diarization (VAD)
Expand Down
29 changes: 29 additions & 0 deletions mlx_audio/stt/models/mms/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# MMS ASR

MLX implementation of Meta's Massively Multilingual Speech (MMS) ASR model, supporting 1000+ languages through language-specific adapter layers on top of a shared wav2vec2 backbone.

## Available Models

| Model | Parameters | Languages | Description |
|-------|------------|-----------|-------------|
| [facebook/mms-1b-fl102](https://huggingface.co/facebook/mms-1b-fl102) | 1B | 102 | Fine-tuned on the FLEURS dataset |
| [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all) | 1B | 1162 | All supported languages |

## Python Usage

```python
from mlx_audio.stt import load

model = load("facebook/mms-1b-fl102")

result = model.generate("audio.wav")
print(result.text)
```

## Architecture

- Wav2Vec2 encoder with convolutional feature extractor
- Per-layer attention adapter modules for language adaptation
- CTC head with language-specific vocabulary
- Adapter weights loaded automatically from `adapter.{lang}.safetensors`
- 16kHz audio input with zero-mean unit-variance normalization
1 change: 1 addition & 0 deletions mlx_audio/stt/models/mms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mms import Model, ModelConfig
163 changes: 163 additions & 0 deletions mlx_audio/stt/models/mms/mms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

from ..base import STTOutput
from ..wav2vec.wav2vec import ModelConfig, Wav2Vec2Model


class Model(nn.Module):
    """MMS (Massively Multilingual Speech) CTC ASR model.

    A shared wav2vec2 encoder followed by a linear CTC head. The head's
    vocabulary and the encoder's adapter layers are language-specific;
    both are attached after weight loading by ``post_load_hook``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        if isinstance(config, dict):
            config = ModelConfig.from_dict(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # Decoding resources, populated best-effort by post_load_hook.
        self._vocab: Optional[Dict[int, str]] = None  # token id -> string
        self._processor = None  # transformers AutoProcessor, if available

    @property
    def sample_rate(self) -> int:
        """Audio sample rate (Hz) the model expects."""
        return 16000

    def __call__(self, input_values: mx.array) -> mx.array:
        """Return CTC logits with shape (batch, frames, vocab_size)."""
        outputs = self.wav2vec2(input_values)
        logits = self.lm_head(outputs.last_hidden_state)
        return logits

    def _ctc_decode(self, logits: mx.array) -> List[List[int]]:
        """Greedy CTC decode: collapse repeated ids, then drop blanks.

        Token id 0 is the CTC blank. Repeats separated by a blank are
        kept (standard CTC semantics).
        """
        # Single device->host transfer instead of one int() per frame.
        predictions = mx.argmax(logits, axis=-1).tolist()
        batch_tokens = []
        for row in predictions:
            tokens = []
            prev = -1
            for token in row:
                if token != prev and token != 0:
                    tokens.append(token)
                prev = token
            batch_tokens.append(tokens)
        return batch_tokens

    def _tokens_to_text(self, tokens: List[int]) -> str:
        """Convert CTC token ids to text.

        Prefers the HF processor, then the id->token vocab; with neither
        available, falls back to space-separated raw ids.
        """
        if self._processor is not None:
            return self._processor.decode(tokens)
        if self._vocab is None:
            return " ".join(str(t) for t in tokens)
        # "|" is the wav2vec2 word-delimiter token.
        return "".join(self._vocab.get(t, "") for t in tokens).replace("|", " ")

    def generate(
        self,
        audio,
        *,
        verbose: bool = False,
        dtype: mx.Dtype = mx.float32,
        **kwargs,
    ) -> STTOutput:
        """Transcribe ``audio`` (file path, array-like, or mx.array).

        Generation options used by other STT models are accepted and
        ignored for API compatibility.
        """
        for unused in (
            "generation_stream",
            "max_tokens",
            "temperature",
            "language",
            "source_lang",
            "target_lang",
            "stream",
        ):
            kwargs.pop(unused, None)

        start_time = time.time()

        if isinstance(audio, (str, Path)):
            from mlx_audio.stt.utils import load_audio

            audio = load_audio(str(audio), sr=self.sample_rate, dtype=dtype)
        elif not isinstance(audio, mx.array):
            audio = mx.array(audio)

        if audio.ndim == 1:
            audio = audio[None, :]

        if audio.dtype != dtype:
            audio = audio.astype(dtype)

        # Clip duration (seconds), used for the segment end timestamp.
        duration = audio.shape[-1] / self.sample_rate

        # Zero-mean / unit-variance normalization per sample, as the
        # wav2vec2 feature extractor expects.
        audio = (audio - mx.mean(audio, axis=-1, keepdims=True)) / (
            mx.std(audio, axis=-1, keepdims=True) + 1e-7
        )

        logits = self(audio)
        mx.eval(logits)

        decoded = self._ctc_decode(logits)
        text = self._tokens_to_text(decoded[0]).strip()

        total_time = time.time() - start_time

        if verbose:
            print(f"Text: {text}")

        return STTOutput(
            text=text,
            segments=[{"text": text, "start": 0.0, "end": round(duration, 3)}],
            total_time=total_time,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Map a PyTorch checkpoint's weights to the MLX layout.

        Conv weights move from (out, in, kernel) to (out, kernel, in),
        weight-norm parametrizations are renamed to weight_g/weight_v,
        and pretraining-only tensors (quantizer, projection heads, mask
        embedding) are dropped.
        """
        sanitized = {}
        for k, v in weights.items():
            if k.endswith(".conv.weight"):
                v = v.swapaxes(1, 2)
            if k.endswith(".conv.weight_v") or k.endswith(".conv.weight_g"):
                v = v.swapaxes(1, 2)
            if k.endswith(".parametrizations.weight.original0"):
                k = k.replace(".parametrizations.weight.original0", ".weight_g")
                v = v.swapaxes(1, 2)
            if k.endswith(".parametrizations.weight.original1"):
                k = k.replace(".parametrizations.weight.original1", ".weight_v")
                v = v.swapaxes(1, 2)
            if (
                k.startswith("quantizer.")
                or k.startswith("project_")
                or k == "masked_spec_embed"
            ):
                continue

            sanitized[k] = v
        return sanitized

    @classmethod
    def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
        """Attach language adapter weights and decoding vocabulary.

        Prefers the English ("eng") adapter, otherwise the first adapter
        in sorted order (deterministic). When vocab.json is nested per
        language, the vocabulary matching the loaded adapter's language
        is selected, falling back to English and then any language.
        """
        model_path = Path(model_path)

        lang = "eng"
        adapter_path = model_path / "adapter.eng.safetensors"
        if not adapter_path.exists():
            # Sorted so the fallback adapter choice is deterministic.
            adapters = sorted(model_path.glob("adapter.*.safetensors"))
            if adapters:
                adapter_path = adapters[0]

        if adapter_path.exists():
            # adapter.<lang>.safetensors -> <lang>
            name_parts = adapter_path.name.split(".")
            if len(name_parts) == 3:
                lang = name_parts[1]
            adapter_weights = mx.load(str(adapter_path))
            sanitized = model.sanitize(adapter_weights)
            model.load_weights(list(sanitized.items()), strict=False)

        vocab_path = model_path / "vocab.json"
        if vocab_path.exists():
            with open(vocab_path) as f:
                vocab = json.load(f)
            # Guard against an empty vocab.json (next() would raise).
            if vocab and isinstance(next(iter(vocab.values())), dict):
                # Nested {lang: {token: id}}: match the loaded adapter.
                lang_vocab = vocab.get(
                    lang,
                    vocab.get("eng", vocab.get("en", next(iter(vocab.values())))),
                )
                model._vocab = {v: k for k, v in lang_vocab.items()}
            elif vocab:
                model._vocab = {v: k for k, v in vocab.items()}

        # Best effort: use the HF processor for decoding when available.
        try:
            from transformers import AutoProcessor

            model._processor = AutoProcessor.from_pretrained(str(model_path))
        except Exception:
            model._processor = None
        return model
Empty file.
135 changes: 135 additions & 0 deletions mlx_audio/stt/models/mms/tests/test_mms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import unittest

import mlx.core as mx

from mlx_audio.stt.models.mms.mms import Model
from mlx_audio.stt.models.wav2vec.wav2vec import ModelConfig


def _small_config():
    """Build a tiny wav2vec2 config so the tests run quickly on CPU."""
    params = dict(
        vocab_size=32,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        conv_dim=(16, 16),
        conv_stride=(2, 2),
        conv_kernel=(4, 3),
        num_feat_extract_layers=2,
        num_conv_pos_embeddings=8,
        num_conv_pos_embedding_groups=4,
    )
    return ModelConfig(**params)


class TestConfig(unittest.TestCase):
    """ModelConfig construction and dict-based overrides."""

    def test_defaults(self):
        defaults = ModelConfig()
        self.assertEqual(defaults.model_type, "wav2vec2")
        self.assertEqual(defaults.hidden_size, 768)

    def test_from_dict(self):
        overrides = {"hidden_size": 1024, "num_hidden_layers": 24}
        self.assertEqual(ModelConfig.from_dict(overrides).hidden_size, 1024)


class TestCTCDecode(unittest.TestCase):
    """Greedy CTC decoding: repeat collapsing and blank handling."""

    def _spiked_logits(self, shape, spikes):
        # Zero logits with 10.0 spikes at each (batch, frame, token id).
        logits = mx.zeros(shape)
        for b, t, idx in spikes:
            logits = logits.at[b, t, idx].add(10.0)
        return logits

    def test_greedy_decode(self):
        model = Model(_small_config())
        # Frames: 5 5 8 <blank> 8 -> repeats collapse, blank splits the 8s.
        logits = self._spiked_logits(
            (1, 10, 32),
            [(0, 0, 5), (0, 1, 5), (0, 2, 8), (0, 3, 0), (0, 4, 8)],
        )
        self.assertEqual(model._ctc_decode(logits)[0], [5, 8, 8])

    def test_all_blanks(self):
        model = Model(_small_config())
        self.assertEqual(model._ctc_decode(mx.zeros((1, 5, 32)))[0], [])


class TestTokensToText(unittest.TestCase):
    """Token-id to text conversion with and without a vocab."""

    def test_with_vocab(self):
        model = Model(_small_config())
        model._vocab = {1: "h", 2: "e", 3: "l", 4: "o", 5: "|"}
        result = model._tokens_to_text([1, 2, 3, 3, 4, 5, 1, 2])
        self.assertEqual(result, "hello he")

    def test_without_vocab(self):
        model = Model(_small_config())
        # With no decoding resources, ids are joined with spaces.
        self.assertEqual(model._tokens_to_text([1, 2, 3]), "1 2 3")


class TestModelSanitize(unittest.TestCase):
    """Checkpoint-key sanitization: kept, transposed, and dropped keys."""

    def setUp(self):
        self.model = Model(_small_config())

    def test_keeps_lm_head(self):
        out = self.model.sanitize({"lm_head.weight": mx.zeros((32, 32))})
        self.assertIn("lm_head.weight", out)

    def test_keeps_wav2vec2_prefix(self):
        key = "wav2vec2.encoder.layers.0.attention.q_proj.weight"
        out = self.model.sanitize({key: mx.zeros((32, 32))})
        self.assertIn(key, out)

    def test_conv_transpose(self):
        key = "wav2vec2.feature_extractor.conv_layers.0.conv.weight"
        out = self.model.sanitize({key: mx.zeros((16, 1, 4))})
        # (out, in, kernel) -> (out, kernel, in)
        self.assertEqual(out[key].shape, (16, 4, 1))

    def test_skips_quantizer(self):
        out = self.model.sanitize({"quantizer.weight_proj.weight": mx.zeros((32, 32))})
        self.assertEqual(len(out), 0)

    def test_skips_masked_spec(self):
        out = self.model.sanitize({"masked_spec_embed": mx.zeros((32,))})
        self.assertEqual(len(out), 0)


class TestModel(unittest.TestCase):
    """Smoke tests for model construction and the forward pass."""

    def setUp(self):
        self.config = _small_config()
        self.model = Model(self.config)

    def test_init(self):
        self.assertIsNotNone(self.model.wav2vec2)
        self.assertIsNotNone(self.model.lm_head)

    def test_sample_rate(self):
        self.assertEqual(self.model.sample_rate, 16000)

    def test_forward(self):
        waveform = mx.random.normal((1, 320))
        logits = self.model(waveform)
        mx.eval(logits)
        # Batch preserved; last axis is the vocabulary.
        self.assertEqual(logits.shape[0], 1)
        self.assertEqual(logits.shape[2], 32)


# Allow running this test module directly: `python test_mms.py`.
if __name__ == "__main__":
    unittest.main()
12 changes: 0 additions & 12 deletions mlx_audio/stt/models/moonshine/moonshine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -462,14 +461,3 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
except Exception:
pass
return model

@classmethod
def from_pretrained(cls, path_or_repo: str, *, dtype: mx.Dtype = mx.float32):
warnings.warn(
"Model.from_pretrained() is deprecated. Use mlx_audio.stt.load() instead.",
DeprecationWarning,
stacklevel=2,
)
from mlx_audio.stt.utils import load

return load(path_or_repo)
Loading