diff --git a/mlx_audio/stt/models/parakeet/parakeet.py b/mlx_audio/stt/models/parakeet/parakeet.py
index 60254c7a..056f1dc5 100644
--- a/mlx_audio/stt/models/parakeet/parakeet.py
+++ b/mlx_audio/stt/models/parakeet/parakeet.py
@@ -36,6 +36,44 @@
 from mlx_audio.utils import from_dict
 
 
+def _is_multilingual_parakeet_v3(vocabulary: list[str] | None) -> bool:
+    """
+    Return True when the vocabulary looks like multilingual Parakeet v3.
+
+    The heuristic requires the ``<|predict_lang|>`` control token plus at
+    least 20 two-letter language tags of the form ``<|xx|>``.
+    """
+    if not vocabulary or "<|predict_lang|>" not in vocabulary:
+        return False
+    language_tokens = [
+        token
+        for token in vocabulary
+        if token.startswith("<|")
+        and token.endswith("|>")
+        and len(token) == 6
+        and token[2:4].isalpha()
+    ]
+    return len(language_tokens) >= 20
+
+
+def _is_non_english_language_hint(language: str | None) -> bool:
+    """
+    Return True when the hint names a language other than English.
+
+    ``None``, blank strings, ``"auto"`` and English hints (``"en"``,
+    ``"english"``, or regional tags such as ``"en-US"`` / ``"en_GB"``) are
+    all treated as "no non-English hint".
+    """
+    if language is None:
+        return False
+    hint = language.strip().lower()
+    if hint in {"", "auto", "english"}:
+        return False
+    # Compare only the primary subtag so that regional variants such as
+    # "en-US" or "en_GB" are still recognised as English.
+    return hint.replace("_", "-").split("-", 1)[0] != "en"
+
+
 @dataclass
 class TDTDecodingArgs:
     model_type: str
@@ -138,7 +176,45 @@ def __init__(self, preprocess_args: PreprocessArgs):
 
         self.preprocessor_config = preprocess_args
 
-    def decode(self, mel: mx.array) -> list[AlignedResult]:
+    def _resolve_multilingual_chunking(
+        self,
+        audio_length_seconds: float,
+        chunk_duration: float | None,
+        overlap_duration: float | None,
+        language: str | None,
+    ) -> tuple[float | None, float | None]:
+        """
+        Pick the chunk/overlap durations to use for multilingual Parakeet v3.
+
+        The caller's values are returned unchanged unless the vocabulary looks
+        like multilingual Parakeet v3, the language hint is non-English,
+        chunking is enabled (chunk_duration is not None) and the audio is
+        longer than 30 seconds. In that case a chunk_duration above 5 seconds
+        is replaced by 5-second chunks with 1 second of overlap.
+        """
+        vocabulary = getattr(self, "vocabulary", None)
+        if not _is_multilingual_parakeet_v3(vocabulary):
+            return chunk_duration, overlap_duration
+        if not _is_non_english_language_hint(language):
+            return chunk_duration, overlap_duration
+        # None means the caller disabled chunking; leave that choice alone.
+        if chunk_duration is None:
+            return chunk_duration, overlap_duration
+        if audio_length_seconds <= 30:
+            return chunk_duration, overlap_duration
+
+        # Parakeet v3 drifts into English on larger windows for non-English
+        # inputs. Keep multilingual windows short enough to preserve language
+        # consistency when the caller provides a non-English language hint.
+        target_chunk = 5.0
+        target_overlap = 1.0
+
+        if chunk_duration > target_chunk:
+            return target_chunk, target_overlap
+
+        return chunk_duration, overlap_duration
+
+    def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
         """
         Decode mel spectrograms to produce transcriptions with the Parakeet model.
         Handles batches and single input. Uses greedy decoding.
@@ -147,10 +223,14 @@ def decode(self, mel: mx.array) -> list[AlignedResult]:
         raise NotImplementedError
 
     def decode_chunk(
-        self, audio_data: mx.array, verbose: bool = False
+        self,
+        audio_data: mx.array,
+        verbose: bool = False,
+        *,
+        language: str | None = None,
     ) -> AlignedResult:
         mel = log_mel_spectrogram(audio_data, self.preprocessor_config)
-        result = self.decode(mel)[0]
+        result = self.decode(mel, language=language)[0]
         if verbose:
             print(result.text)
         return result
@@ -164,6 +244,7 @@ def generate(
         overlap_duration: Optional[float] = None,
         chunk_callback: Optional[Callable] = None,
         stream: bool = False,
+        language: str | None = None,
         **kwargs,
     ) -> AlignedResult | Generator[StreamingResult, None, None]:
         """
@@ -195,6 +276,7 @@ def generate(
                 chunk_duration=5.0 if chunk_duration is None else chunk_duration,
                 overlap_duration=1.0 if overlap_duration is None else overlap_duration,
                 verbose=verbose,
+                language=language,
                 **kwargs,
             )
 
@@ -206,9 +288,13 @@ def generate(
         else:
             # mx.array input
             audio_data = audio.astype(dtype) if audio.dtype != dtype else audio
+        audio_length_seconds = len(audio_data) / self.preprocessor_config.sample_rate
+        chunk_duration, overlap_duration = self._resolve_multilingual_chunking(
+            audio_length_seconds, chunk_duration, overlap_duration, language
+        )
 
         if chunk_duration is None:
-            return self.decode_chunk(audio_data, verbose)
+            return self.decode_chunk(audio_data, verbose, language=language)
 
         overlap_duration = 15.0 if overlap_duration is None else overlap_duration
 
@@ -218,10 +304,8 @@ def generate(
                 f"chunk_duration ({chunk_duration}s)."
             )
 
-        audio_length_seconds = len(audio_data) / self.preprocessor_config.sample_rate
-
         if audio_length_seconds <= chunk_duration:
-            return self.decode_chunk(audio_data, verbose)
+            return self.decode_chunk(audio_data, verbose, language=language)
 
         chunk_samples = int(chunk_duration * self.preprocessor_config.sample_rate)
         overlap_samples = int(overlap_duration * self.preprocessor_config.sample_rate)
@@ -240,7 +324,7 @@ def generate(
             chunk_audio = audio_data[start:end]
             chunk_mel = log_mel_spectrogram(chunk_audio, self.preprocessor_config)
 
-            chunk_result = self.decode(chunk_mel)[0]
+            chunk_result = self.decode(chunk_mel, language=language)[0]
 
             chunk_offset = start / self.preprocessor_config.sample_rate
             for sentence in chunk_result.sentences:
@@ -289,6 +373,7 @@ def stream_generate(
         chunk_duration: float = 5.0,
         overlap_duration: float = 1.0,
         verbose: bool = False,
+        language: str | None = None,
         **kwargs,
     ) -> Generator[StreamingResult, None, None]:
         """
@@ -346,7 +431,7 @@ def stream_generate(
             chunk_audio = audio_data[start:end]
             chunk_mel = log_mel_spectrogram(chunk_audio, self.preprocessor_config)
 
-            chunk_result = self.decode(chunk_mel)[0]
+            chunk_result = self.decode(chunk_mel, language=language)[0]
 
             # Adjust timestamps for chunk offset
             chunk_offset = start / sample_rate
@@ -503,7 +588,7 @@ def __init__(self, args: ParakeetTDTArgs):
         self.decoder = PredictNetwork(args.decoder)
         self.joint = JointNetwork(args.joint)
 
-    def decode(self, mel: mx.array) -> list[AlignedResult]:
+    def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
         """
         Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
         mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
@@ -611,7 +696,7 @@ def __init__(self, args: ParakeetRNNTArgs):
         self.decoder = PredictNetwork(args.decoder)
         self.joint = JointNetwork(args.joint)
 
-    def decode(self, mel: mx.array) -> list[AlignedResult]:
+    def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
         """
         Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
         mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
@@ -706,7 +791,7 @@ def __init__(self, args: ParakeetCTCArgs):
         self.encoder = Conformer(args.encoder)
         self.decoder = ConvASRDecoder(args.decoder)
 
-    def decode(self, mel: mx.array) -> list[AlignedResult]:
+    def decode(self, mel: mx.array, *, language: str | None = None) -> list[AlignedResult]:
         """
         Generate with CTC decoding for the Parakeet model, handling batches and single input. Uses greedy decoding.
         mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
diff --git a/mlx_audio/stt/tests/test_models.py b/mlx_audio/stt/tests/test_models.py
index 0331c935..34b2d762 100644
--- a/mlx_audio/stt/tests/test_models.py
+++ b/mlx_audio/stt/tests/test_models.py
@@ -4,6 +4,7 @@
 from unittest.mock import ANY, MagicMock, PropertyMock, patch
 
 import mlx.core as mx
+import mlx.nn as nn
 import numpy as np
 
 
@@ -471,6 +472,170 @@ def hf_hub_download_side_effect(repo_id_arg, filename_arg):
         self.assertEqual(model.vocabulary, dummy_vocabulary)
         self.assertEqual(model.durations, [0, 1, 2, 3])
 
+    def test_multilingual_v3_detection_requires_predict_lang_and_language_tokens(self):
+        from mlx_audio.stt.models.parakeet.parakeet import _is_multilingual_parakeet_v3
+
+        multilingual_vocab = ["", "<|predict_lang|>"] + [
+            f"<|{code}|>"
+            for code in (
+                "en",
+                "es",
+                "fr",
+                "de",
+                "it",
+                "pt",
+                "nl",
+                "ru",
+                "uk",
+                "pl",
+                "ro",
+                "hu",
+                "bg",
+                "hr",
+                "cs",
+                "da",
+                "et",
+                "fi",
+                "el",
+                "lv",
+            )
+        ]
+        english_only_vocab = ["", "<|en|>"]
+
+        self.assertTrue(_is_multilingual_parakeet_v3(multilingual_vocab))
+        self.assertFalse(_is_multilingual_parakeet_v3(english_only_vocab))
+
+    def test_non_english_hint_treats_regional_english_as_english(self):
+        from mlx_audio.stt.models.parakeet.parakeet import _is_non_english_language_hint
+
+        for hint in (None, "", "auto", "en", "English", "en-US", "en_GB"):
+            self.assertFalse(_is_non_english_language_hint(hint))
+        for hint in ("es", "spanish", "fr-FR"):
+            self.assertTrue(_is_non_english_language_hint(hint))
+
+    def test_multilingual_chunking_clamps_non_english_hint(self):
+        from mlx_audio.stt.models.parakeet.audio import PreprocessArgs
+        from mlx_audio.stt.models.parakeet.parakeet import Model
+
+        class DummyParakeetModel(Model):
+            def __new__(cls):
+                return nn.Module.__new__(cls)
+
+            def __init__(self):
+                super().__init__(
+                    PreprocessArgs(
+                        sample_rate=16000,
+                        normalize="per_feature",
+                        window_size=0.02,
+                        window_stride=0.01,
+                        window="hann",
+                        features=80,
+                        n_fft=512,
+                        dither=1e-5,
+                    )
+                )
+                self.vocabulary = ["", "<|predict_lang|>"] + [
+                    f"<|{code}|>"
+                    for code in (
+                        "en",
+                        "es",
+                        "fr",
+                        "de",
+                        "it",
+                        "pt",
+                        "nl",
+                        "ru",
+                        "uk",
+                        "pl",
+                        "ro",
+                        "hu",
+                        "bg",
+                        "hr",
+                        "cs",
+                        "da",
+                        "et",
+                        "fi",
+                        "el",
+                        "lv",
+                    )
+                ]
+
+            def decode(self, mel: mx.array, *, language: str | None = None):
+                return []
+
+        model = DummyParakeetModel()
+        chunk, overlap = model._resolve_multilingual_chunking(
+            audio_length_seconds=403.0,
+            chunk_duration=120.0,
+            overlap_duration=15.0,
+            language="spanish",
+        )
+
+        self.assertEqual(chunk, 5.0)
+        self.assertEqual(overlap, 1.0)
+
+    def test_multilingual_chunking_preserves_english_and_auto(self):
+        from mlx_audio.stt.models.parakeet.audio import PreprocessArgs
+        from mlx_audio.stt.models.parakeet.parakeet import Model
+
+        class DummyParakeetModel(Model):
+            def __new__(cls):
+                return nn.Module.__new__(cls)
+
+            def __init__(self):
+                super().__init__(
+                    PreprocessArgs(
+                        sample_rate=16000,
+                        normalize="per_feature",
+                        window_size=0.02,
+                        window_stride=0.01,
+                        window="hann",
+                        features=80,
+                        n_fft=512,
+                        dither=1e-5,
+                    )
+                )
+                self.vocabulary = ["", "<|predict_lang|>"] + [
+                    f"<|{code}|>"
+                    for code in (
+                        "en",
+                        "es",
+                        "fr",
+                        "de",
+                        "it",
+                        "pt",
+                        "nl",
+                        "ru",
+                        "uk",
+                        "pl",
+                        "ro",
+                        "hu",
+                        "bg",
+                        "hr",
+                        "cs",
+                        "da",
+                        "et",
+                        "fi",
+                        "el",
+                        "lv",
+                    )
+                ]
+
+            def decode(self, mel: mx.array, *, language: str | None = None):
+                return []
+
+        model = DummyParakeetModel()
+
+        for hint in (None, "auto", "english", "en"):
+            chunk, overlap = model._resolve_multilingual_chunking(
+                audio_length_seconds=403.0,
+                chunk_duration=120.0,
+                overlap_duration=15.0,
+                language=hint,
+            )
+            self.assertEqual(chunk, 120.0)
+            self.assertEqual(overlap, 15.0)
+
 
 class TestGLMASRModel(unittest.TestCase):
     """Tests for the GLM-ASR model."""