Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 70 additions & 12 deletions mlx_audio/stt/models/parakeet/parakeet.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@
from mlx_audio.utils import from_dict


def _is_multilingual_parakeet_v3(vocabulary: list[str] | None) -> bool:
if not vocabulary or "<|predict_lang|>" not in vocabulary:
return False
language_tokens = [
token
for token in vocabulary
if token.startswith("<|")
and token.endswith("|>")
and len(token) == 6
and token[2:4].isalpha()
]
return len(language_tokens) >= 20


def _is_non_english_language_hint(language: str | None) -> bool:
if language is None:
return False
hint = language.strip().lower()
return hint not in {"", "auto", "english", "en"}


@dataclass
class TDTDecodingArgs:
model_type: str
Expand Down Expand Up @@ -138,7 +159,35 @@ def __init__(self, preprocess_args: PreprocessArgs):

self.preprocessor_config = preprocess_args

def decode(self, mel: mx.array) -> list[AlignedResult]:
def _resolve_multilingual_chunking(
    self,
    audio_length_seconds: float,
    chunk_duration: float | None,
    overlap_duration: float,
    language: str | None,
) -> tuple[float | None, float]:
    """Clamp chunking windows for multilingual Parakeet v3 models.

    Parakeet v3 tends to drift into English on long decoding windows
    when transcribing non-English audio.  When the loaded vocabulary is
    multilingual, the caller supplied a non-English language hint, and
    the audio is long enough to be chunked, the chunk/overlap windows
    are clamped to short values to preserve language consistency.

    Args:
        audio_length_seconds: Total input duration in seconds.
        chunk_duration: Requested chunk length in seconds, or ``None``
            for single-pass decoding.
        overlap_duration: Requested overlap between chunks in seconds.
        language: Optional caller-provided language hint.

    Returns:
        A ``(chunk_duration, overlap_duration)`` pair — clamped to
        ``(5.0, 1.0)`` when the safeguard applies and the requested
        chunk is longer than 5 s, otherwise the caller's values
        unchanged.
    """
    vocabulary = getattr(self, "vocabulary", None)

    # The safeguard only applies to multilingual v3 checkpoints with an
    # explicit non-English hint, on audio that will actually be chunked
    # (chunking requested and input longer than 30 s).
    if (
        not _is_multilingual_parakeet_v3(vocabulary)
        or not _is_non_english_language_hint(language)
        or chunk_duration is None
        or audio_length_seconds <= 30
    ):
        return chunk_duration, overlap_duration

    # Keep multilingual windows short; larger windows let the model
    # drift into English on non-English inputs.
    target_chunk = 5.0
    target_overlap = 1.0

    # NOTE: chunk_duration cannot be None here (guarded above), so the
    # original's redundant "is None" re-check was removed as dead code.
    if chunk_duration > target_chunk:
        return target_chunk, target_overlap

    return chunk_duration, overlap_duration

def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
"""
Decode mel spectrograms to produce transcriptions with the Parakeet model.
Handles batches and single input. Uses greedy decoding.
Expand All @@ -147,10 +196,14 @@ def decode(self, mel: mx.array) -> list[AlignedResult]:
raise NotImplementedError

def decode_chunk(
    self,
    audio_data: mx.array,
    verbose: bool = False,
    *,
    language: str | None = None,
) -> AlignedResult:
    """Transcribe a single chunk of raw audio.

    Converts the audio to a log-mel spectrogram, decodes it in one
    pass, and returns the first aligned result.  When *verbose* is set,
    the transcription text is also printed to stdout.
    """
    features = log_mel_spectrogram(audio_data, self.preprocessor_config)
    results = self.decode(features, language=language)
    first = results[0]
    if verbose:
        print(first.text)
    return first
Expand All @@ -164,6 +217,7 @@ def generate(
overlap_duration: Optional[float] = None,
chunk_callback: Optional[Callable] = None,
stream: bool = False,
language: str | None = None,
**kwargs,
) -> AlignedResult | Generator[StreamingResult, None, None]:
"""
Expand Down Expand Up @@ -195,6 +249,7 @@ def generate(
chunk_duration=5.0 if chunk_duration is None else chunk_duration,
overlap_duration=1.0 if overlap_duration is None else overlap_duration,
verbose=verbose,
language=language,
**kwargs,
)

Expand All @@ -206,9 +261,13 @@ def generate(
else:
# mx.array input
audio_data = audio.astype(dtype) if audio.dtype != dtype else audio
audio_length_seconds = len(audio_data) / self.preprocessor_config.sample_rate
chunk_duration, overlap_duration = self._resolve_multilingual_chunking(
audio_length_seconds, chunk_duration, overlap_duration, language
)

if chunk_duration is None:
return self.decode_chunk(audio_data, verbose)
return self.decode_chunk(audio_data, verbose, language=language)

overlap_duration = 15.0 if overlap_duration is None else overlap_duration

Expand All @@ -218,10 +277,8 @@ def generate(
f"chunk_duration ({chunk_duration}s)."
)

audio_length_seconds = len(audio_data) / self.preprocessor_config.sample_rate

if audio_length_seconds <= chunk_duration:
return self.decode_chunk(audio_data, verbose)
return self.decode_chunk(audio_data, verbose, language=language)

chunk_samples = int(chunk_duration * self.preprocessor_config.sample_rate)
overlap_samples = int(overlap_duration * self.preprocessor_config.sample_rate)
Expand All @@ -240,7 +297,7 @@ def generate(
chunk_audio = audio_data[start:end]
chunk_mel = log_mel_spectrogram(chunk_audio, self.preprocessor_config)

chunk_result = self.decode(chunk_mel)[0]
chunk_result = self.decode(chunk_mel, language=language)[0]

chunk_offset = start / self.preprocessor_config.sample_rate
for sentence in chunk_result.sentences:
Expand Down Expand Up @@ -289,6 +346,7 @@ def stream_generate(
chunk_duration: float = 5.0,
overlap_duration: float = 1.0,
verbose: bool = False,
language: str | None = None,
**kwargs,
) -> Generator[StreamingResult, None, None]:
"""
Expand Down Expand Up @@ -346,7 +404,7 @@ def stream_generate(
chunk_audio = audio_data[start:end]
chunk_mel = log_mel_spectrogram(chunk_audio, self.preprocessor_config)

chunk_result = self.decode(chunk_mel)[0]
chunk_result = self.decode(chunk_mel, language=language)[0]

# Adjust timestamps for chunk offset
chunk_offset = start / sample_rate
Expand Down Expand Up @@ -503,7 +561,7 @@ def __init__(self, args: ParakeetTDTArgs):
self.decoder = PredictNetwork(args.decoder)
self.joint = JointNetwork(args.joint)

def decode(self, mel: mx.array) -> list[AlignedResult]:
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
"""
Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
Expand Down Expand Up @@ -611,7 +669,7 @@ def __init__(self, args: ParakeetRNNTArgs):
self.decoder = PredictNetwork(args.decoder)
self.joint = JointNetwork(args.joint)

def decode(self, mel: mx.array) -> list[AlignedResult]:
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
"""
Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
Expand Down Expand Up @@ -706,7 +764,7 @@ def __init__(self, args: ParakeetCTCArgs):
self.encoder = Conformer(args.encoder)
self.decoder = ConvASRDecoder(args.decoder)

def decode(self, mel: mx.array) -> list[AlignedResult]:
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
"""
Generate with CTC decoding for the Parakeet model, handling batches and single input. Uses greedy decoding.
mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
Expand Down
157 changes: 157 additions & 0 deletions mlx_audio/stt/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import ANY, MagicMock, PropertyMock, patch

import mlx.core as mx
import mlx.nn as nn
import numpy as np


Expand Down Expand Up @@ -471,6 +472,162 @@ def hf_hub_download_side_effect(repo_id_arg, filename_arg):
self.assertEqual(model.vocabulary, dummy_vocabulary)
self.assertEqual(model.durations, [0, 1, 2, 3])

def test_multilingual_v3_detection_requires_predict_lang_and_language_tokens(self):
    """Detection needs <|predict_lang|> plus enough <|xx|> language tags."""
    from mlx_audio.stt.models.parakeet.parakeet import _is_multilingual_parakeet_v3

    codes = (
        "en es fr de it pt nl ru uk pl "
        "ro hu bg hr cs da et fi el lv"
    ).split()
    multilingual_vocab = ["<unk>", "<|predict_lang|>"]
    multilingual_vocab.extend(f"<|{code}|>" for code in codes)
    english_only_vocab = ["<unk>", "<|en|>"]

    self.assertTrue(_is_multilingual_parakeet_v3(multilingual_vocab))
    self.assertFalse(_is_multilingual_parakeet_v3(english_only_vocab))

def test_multilingual_chunking_clamps_non_english_hint(self):
    """A non-English hint on long audio clamps windows to (5.0, 1.0)."""
    from mlx_audio.stt.models.parakeet.audio import PreprocessArgs
    from mlx_audio.stt.models.parakeet.parakeet import Model

    codes = (
        "en es fr de it pt nl ru uk pl "
        "ro hu bg hr cs da et fi el lv"
    ).split()

    class DummyParakeetModel(Model):
        def __new__(cls):
            # Bypass Model.__new__ so the dummy needs no model weights.
            return nn.Module.__new__(cls)

        def __init__(self):
            super().__init__(
                PreprocessArgs(
                    sample_rate=16000,
                    normalize="per_feature",
                    window_size=0.02,
                    window_stride=0.01,
                    window="hann",
                    features=80,
                    n_fft=512,
                    dither=1e-5,
                )
            )
            self.vocabulary = ["<unk>", "<|predict_lang|>"] + [
                f"<|{code}|>" for code in codes
            ]

        def decode(self, mel: mx.array, *, language: str | None = None):
            return []

    model = DummyParakeetModel()
    chunk, overlap = model._resolve_multilingual_chunking(
        audio_length_seconds=403.0,
        chunk_duration=120.0,
        overlap_duration=15.0,
        language="spanish",
    )

    self.assertEqual(chunk, 5.0)
    self.assertEqual(overlap, 1.0)

def test_multilingual_chunking_preserves_english_and_auto(self):
    """English/auto/unset hints leave the caller's windows untouched."""
    from mlx_audio.stt.models.parakeet.audio import PreprocessArgs
    from mlx_audio.stt.models.parakeet.parakeet import Model

    codes = (
        "en es fr de it pt nl ru uk pl "
        "ro hu bg hr cs da et fi el lv"
    ).split()

    class DummyParakeetModel(Model):
        def __new__(cls):
            # Bypass Model.__new__ so the dummy needs no model weights.
            return nn.Module.__new__(cls)

        def __init__(self):
            super().__init__(
                PreprocessArgs(
                    sample_rate=16000,
                    normalize="per_feature",
                    window_size=0.02,
                    window_stride=0.01,
                    window="hann",
                    features=80,
                    n_fft=512,
                    dither=1e-5,
                )
            )
            self.vocabulary = ["<unk>", "<|predict_lang|>"] + [
                f"<|{code}|>" for code in codes
            ]

        def decode(self, mel: mx.array, *, language: str | None = None):
            return []

    model = DummyParakeetModel()

    for hint in (None, "auto", "english", "en"):
        chunk, overlap = model._resolve_multilingual_chunking(
            audio_length_seconds=403.0,
            chunk_duration=120.0,
            overlap_duration=15.0,
            language=hint,
        )
        self.assertEqual(chunk, 120.0)
        self.assertEqual(overlap, 15.0)


class TestGLMASRModel(unittest.TestCase):
"""Tests for the GLM-ASR model."""
Expand Down