Skip to content

Commit b5c152c

Browse files
committed
fix(parakeet): clamp multilingual v3 chunk windows for non-English input
1 parent f7328a4 commit b5c152c

File tree

2 files changed

+225
-12
lines changed

2 files changed

+225
-12
lines changed

mlx_audio/stt/models/parakeet/parakeet.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,26 @@
3535
from mlx_audio.stt.utils import load_audio
3636
from mlx_audio.utils import from_dict
3737

38+
def _is_multilingual_parakeet_v3(vocabulary: list[str] | None) -> bool:
39+
if not vocabulary or "<|predict_lang|>" not in vocabulary:
40+
return False
41+
language_tokens = [
42+
token
43+
for token in vocabulary
44+
if token.startswith("<|")
45+
and token.endswith("|>")
46+
and len(token) == 6
47+
and token[2:4].isalpha()
48+
]
49+
return len(language_tokens) >= 20
50+
51+
52+
def _is_non_english_language_hint(language: str | None) -> bool:
53+
if language is None:
54+
return False
55+
hint = language.strip().lower()
56+
return hint not in {"", "auto", "english", "en"}
57+
3858

3959
@dataclass
4060
class PreprocessArgs:
@@ -160,7 +180,33 @@ def __init__(self, preprocess_args: PreprocessArgs):
160180

161181
self.preprocessor_config = preprocess_args
162182

163-
def decode(self, mel: mx.array) -> list[AlignedResult]:
183+
def _resolve_multilingual_chunking(
    self,
    audio_length_seconds: float,
    chunk_duration: float | None,
    overlap_duration: float,
    language: str | None,
) -> tuple[float | None, float]:
    """Pick the chunk/overlap durations to use for this transcription.

    The caller's values are returned untouched unless all of the following
    hold: the model's ``vocabulary`` attribute looks like a multilingual
    Parakeet v3 vocabulary, ``language`` is a non-English hint, and the
    audio is longer than 30 seconds. In that case a requested chunk that is
    missing or longer than 5 seconds is replaced by a 5 second chunk with a
    1 second overlap.

    Args:
        audio_length_seconds: Total length of the input audio in seconds.
        chunk_duration: Requested chunk length in seconds, or ``None`` for
            no chunking.
        overlap_duration: Requested overlap between chunks in seconds.
        language: Optional language hint supplied by the caller.

    Returns:
        A ``(chunk_duration, overlap_duration)`` pair to use for decoding.
    """
    requested = (chunk_duration, overlap_duration)

    # The model may not expose a vocabulary at all; treat that as "not v3".
    multilingual_non_english = _is_multilingual_parakeet_v3(
        getattr(self, "vocabulary", None)
    ) and _is_non_english_language_hint(language)
    if not multilingual_non_english or audio_length_seconds <= 30:
        return requested

    # Parakeet v3 drifts into English on larger windows for non-English
    # inputs. Keep multilingual windows short enough to preserve language
    # consistency when the caller provides a non-English language hint.
    short_chunk, short_overlap = 5.0, 1.0

    needs_clamp = chunk_duration is None or chunk_duration > short_chunk
    return (short_chunk, short_overlap) if needs_clamp else requested
208+
209+
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
164210
"""
165211
Decode mel spectrograms to produce transcriptions with the Parakeet model.
166212
Handles batches and single input. Uses greedy decoding.
@@ -169,10 +215,14 @@ def decode(self, mel: mx.array) -> list[AlignedResult]:
169215
raise NotImplementedError
170216

171217
def decode_chunk(
172-
self, audio_data: mx.array, verbose: bool = False
218+
self,
219+
audio_data: mx.array,
220+
verbose: bool = False,
221+
*,
222+
language: str | None = None,
173223
) -> AlignedResult:
174224
mel = log_mel_spectrogram(audio_data, self.preprocessor_config)
175-
result = self.decode(mel)[0]
225+
result = self.decode(mel, language=language)[0]
176226
if verbose:
177227
print(result.text)
178228
return result
@@ -186,6 +236,7 @@ def generate(
186236
overlap_duration: float = 15.0,
187237
chunk_callback: Optional[Callable] = None,
188238
stream: bool = False,
239+
language: str | None = None,
189240
**kwargs,
190241
) -> AlignedResult | Generator[StreamingResult, None, None]:
191242
"""
@@ -214,6 +265,7 @@ def generate(
214265
chunk_duration=chunk_duration,
215266
overlap_duration=overlap_duration,
216267
verbose=verbose,
268+
language=language,
217269
**kwargs,
218270
)
219271

@@ -226,13 +278,16 @@ def generate(
226278
# mx.array input
227279
audio_data = audio.astype(dtype) if audio.dtype != dtype else audio
228280

229-
if chunk_duration is None:
230-
return self.decode_chunk(audio_data, verbose)
231-
232281
audio_length_seconds = len(audio_data) / self.preprocessor_config.sample_rate
282+
chunk_duration, overlap_duration = self._resolve_multilingual_chunking(
283+
audio_length_seconds, chunk_duration, overlap_duration, language
284+
)
285+
286+
if chunk_duration is None:
287+
return self.decode_chunk(audio_data, verbose, language=language)
233288

234289
if audio_length_seconds <= chunk_duration:
235-
return self.decode_chunk(audio_data, verbose)
290+
return self.decode_chunk(audio_data, verbose, language=language)
236291

237292
chunk_samples = int(chunk_duration * self.preprocessor_config.sample_rate)
238293
overlap_samples = int(overlap_duration * self.preprocessor_config.sample_rate)
@@ -251,7 +306,7 @@ def generate(
251306
chunk_audio = audio_data[start:end]
252307
chunk_mel = log_mel_spectrogram(chunk_audio, self.preprocessor_config)
253308

254-
chunk_result = self.decode(chunk_mel)[0]
309+
chunk_result = self.decode(chunk_mel, language=language)[0]
255310

256311
chunk_offset = start / self.preprocessor_config.sample_rate
257312
for sentence in chunk_result.sentences:
@@ -300,6 +355,7 @@ def stream_generate(
300355
chunk_duration: float = 5.0,
301356
overlap_duration: float = 1.0,
302357
verbose: bool = False,
358+
language: str | None = None,
303359
**kwargs,
304360
) -> Generator[StreamingResult, None, None]:
305361
"""
@@ -351,7 +407,7 @@ def stream_generate(
351407
chunk_audio = audio_data[start:end]
352408
chunk_mel = log_mel_spectrogram(chunk_audio, self.preprocessor_config)
353409

354-
chunk_result = self.decode(chunk_mel)[0]
410+
chunk_result = self.decode(chunk_mel, language=language)[0]
355411

356412
# Adjust timestamps for chunk offset
357413
chunk_offset = start / sample_rate
@@ -508,7 +564,7 @@ def __init__(self, args: ParakeetTDTArgs):
508564
self.decoder = PredictNetwork(args.decoder)
509565
self.joint = JointNetwork(args.joint)
510566

511-
def decode(self, mel: mx.array) -> list[AlignedResult]:
567+
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
512568
"""
513569
Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
514570
mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
@@ -615,7 +671,7 @@ def __init__(self, args: ParakeetRNNTArgs):
615671
self.decoder = PredictNetwork(args.decoder)
616672
self.joint = JointNetwork(args.joint)
617673

618-
def decode(self, mel: mx.array) -> list[AlignedResult]:
674+
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
619675
"""
620676
Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
621677
mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
@@ -709,7 +765,7 @@ def __init__(self, args: ParakeetCTCArgs):
709765
self.encoder = Conformer(args.encoder)
710766
self.decoder = ConvASRDecoder(args.decoder)
711767

712-
def decode(self, mel: mx.array) -> list[AlignedResult]:
768+
def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
713769
"""
714770
Generate with CTC decoding for the Parakeet model, handling batches and single input. Uses greedy decoding.
715771
mel: [batch, sequence, mel_dim] or [sequence, mel_dim]

mlx_audio/stt/tests/test_models.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest.mock import ANY, MagicMock, PropertyMock, patch
55

66
import mlx.core as mx
7+
import mlx.nn as nn
78
import numpy as np
89

910

@@ -395,6 +396,162 @@ def hf_hub_download_side_effect(repo_id_arg, filename_arg):
395396
self.assertEqual(model.vocabulary, dummy_vocabulary)
396397
self.assertEqual(model.durations, [0, 1, 2, 3])
397398

399+
def test_multilingual_v3_detection_requires_predict_lang_and_language_tokens(self):
    """Detection needs both the predict-lang marker and enough language tokens."""
    from mlx_audio.stt.models.parakeet.parakeet import _is_multilingual_parakeet_v3

    language_tokens = [
        f"<|{code}|>"
        for code in (
            "en",
            "es",
            "fr",
            "de",
            "it",
            "pt",
            "nl",
            "ru",
            "uk",
            "pl",
            "ro",
            "hu",
            "bg",
            "hr",
            "cs",
            "da",
            "et",
            "fi",
            "el",
            "lv",
        )
    ]
    multilingual_vocab = ["<unk>", "<|predict_lang|>"] + language_tokens
    english_only_vocab = ["<unk>", "<|en|>"]

    self.assertTrue(_is_multilingual_parakeet_v3(multilingual_vocab))
    self.assertFalse(_is_multilingual_parakeet_v3(english_only_vocab))

    # Each requirement must be enforced on its own: plenty of language
    # tokens without the marker, and the marker with one token too few.
    self.assertFalse(_is_multilingual_parakeet_v3(["<unk>"] + language_tokens))
    self.assertFalse(
        _is_multilingual_parakeet_v3(
            ["<unk>", "<|predict_lang|>"] + language_tokens[:-1]
        )
    )

    # A missing or empty vocabulary is never multilingual.
    self.assertFalse(_is_multilingual_parakeet_v3(None))
    self.assertFalse(_is_multilingual_parakeet_v3([]))
430+
self.assertFalse(_is_multilingual_parakeet_v3(english_only_vocab))
431+
432+
def test_multilingual_chunking_clamps_non_english_hint(self):
    """A non-English hint on long audio clamps the window to 5s / 1s."""
    from mlx_audio.stt.models.parakeet.audio import PreprocessArgs
    from mlx_audio.stt.models.parakeet.parakeet import Model

    language_codes = (
        "en es fr de it pt nl ru uk pl ro hu bg hr cs da et fi el lv".split()
    )

    class DummyParakeetModel(Model):
        # NOTE(review): presumably sidesteps a custom Model.__new__ so the
        # stub can be built without real weights - confirm against Model.
        def __new__(cls):
            return nn.Module.__new__(cls)

        def __init__(self):
            preprocess_args = PreprocessArgs(
                sample_rate=16000,
                normalize="per_feature",
                window_size=0.02,
                window_stride=0.01,
                window="hann",
                features=80,
                n_fft=512,
                dither=1e-5,
            )
            super().__init__(preprocess_args)
            # Marker token plus 20 two-letter language tokens.
            self.vocabulary = ["<unk>", "<|predict_lang|>"] + [
                f"<|{code}|>" for code in language_codes
            ]

        def decode(self, mel: mx.array, *, language: str | None = None):
            return []

    dummy = DummyParakeetModel()
    resolved_chunk, resolved_overlap = dummy._resolve_multilingual_chunking(
        audio_length_seconds=403.0,
        chunk_duration=120.0,
        overlap_duration=15.0,
        language="spanish",
    )

    self.assertEqual(resolved_chunk, 5.0)
    self.assertEqual(resolved_overlap, 1.0)
492+
493+
def test_multilingual_chunking_preserves_english_and_auto(self):
    """Missing, auto and English hints leave the caller's window untouched."""
    from mlx_audio.stt.models.parakeet.audio import PreprocessArgs
    from mlx_audio.stt.models.parakeet.parakeet import Model

    language_codes = (
        "en es fr de it pt nl ru uk pl ro hu bg hr cs da et fi el lv".split()
    )

    class DummyParakeetModel(Model):
        # NOTE(review): presumably sidesteps a custom Model.__new__ so the
        # stub can be built without real weights - confirm against Model.
        def __new__(cls):
            return nn.Module.__new__(cls)

        def __init__(self):
            preprocess_args = PreprocessArgs(
                sample_rate=16000,
                normalize="per_feature",
                window_size=0.02,
                window_stride=0.01,
                window="hann",
                features=80,
                n_fft=512,
                dither=1e-5,
            )
            super().__init__(preprocess_args)
            # Marker token plus 20 two-letter language tokens.
            self.vocabulary = ["<unk>", "<|predict_lang|>"] + [
                f"<|{code}|>" for code in language_codes
            ]

        def decode(self, mel: mx.array, *, language: str | None = None):
            return []

    dummy = DummyParakeetModel()

    english_or_unspecified = (None, "auto", "english", "en")
    for hint in english_or_unspecified:
        resolved_chunk, resolved_overlap = dummy._resolve_multilingual_chunking(
            audio_length_seconds=403.0,
            chunk_duration=120.0,
            overlap_duration=15.0,
            language=hint,
        )
        self.assertEqual(resolved_chunk, 120.0)
        self.assertEqual(resolved_overlap, 15.0)
554+
398555

399556
class TestGLMASRModel(unittest.TestCase):
400557
"""Tests for the GLM-ASR model."""

0 commit comments

Comments
 (0)