3535from mlx_audio .stt .utils import load_audio
3636from mlx_audio .utils import from_dict
3737
38+ def _is_multilingual_parakeet_v3 (vocabulary : list [str ] | None ) -> bool :
39+ if not vocabulary or "<|predict_lang|>" not in vocabulary :
40+ return False
41+ language_tokens = [
42+ token
43+ for token in vocabulary
44+ if token .startswith ("<|" )
45+ and token .endswith ("|>" )
46+ and len (token ) == 6
47+ and token [2 :4 ].isalpha ()
48+ ]
49+ return len (language_tokens ) >= 20
50+
51+
52+ def _is_non_english_language_hint (language : str | None ) -> bool :
53+ if language is None :
54+ return False
55+ hint = language .strip ().lower ()
56+ return hint not in {"" , "auto" , "english" , "en" }
57+
3858
3959@dataclass
4060class PreprocessArgs :
@@ -160,7 +180,33 @@ def __init__(self, preprocess_args: PreprocessArgs):
160180
161181 self .preprocessor_config = preprocess_args
162182
163- def decode (self , mel : mx .array ) -> list [AlignedResult ]:
def _resolve_multilingual_chunking(
    self,
    audio_length_seconds: float,
    chunk_duration: float | None,
    overlap_duration: float,
    language: str | None,
) -> tuple[float | None, float]:
    """Pick the chunk and overlap durations to use for this transcription.

    The caller's values are returned untouched unless all of the following
    hold: the model's ``vocabulary`` attribute (if any) looks like
    multilingual Parakeet v3, ``language`` is a non-English hint, and the
    audio is longer than 30 seconds. In that case a missing chunk duration,
    or one above 5.0 seconds, is replaced by a 5.0 second chunk with a
    1.0 second overlap.

    Args:
        audio_length_seconds: Duration of the input audio in seconds.
        chunk_duration: Requested chunk length in seconds, or ``None``.
        overlap_duration: Requested overlap between chunks in seconds.
        language: Optional language hint supplied by the caller.

    Returns:
        A ``(chunk_duration, overlap_duration)`` pair.
    """
    requested = (chunk_duration, overlap_duration)

    # Evaluated left to right with short-circuiting: vocabulary check first,
    # then the language hint, then the audio length.
    override_applies = (
        _is_multilingual_parakeet_v3(getattr(self, "vocabulary", None))
        and _is_non_english_language_hint(language)
        and not (audio_length_seconds <= 30)
    )
    if not override_applies:
        return requested

    # Parakeet v3 drifts into English on larger windows for non-English
    # inputs. Keep multilingual windows short enough to preserve language
    # consistency when the caller provides a non-English language hint.
    short_chunk = 5.0
    short_overlap = 1.0

    window_too_long = chunk_duration is None or chunk_duration > short_chunk
    return (short_chunk, short_overlap) if window_too_long else requested
208+
209+ def decode (self , mel : mx .array , * , language : str | None = None ) -> list [AlignedResult ]:
164210 """
165211 Decode mel spectrograms to produce transcriptions with the Parakeet model.
166212 Handles batches and single input. Uses greedy decoding.
@@ -169,10 +215,14 @@ def decode(self, mel: mx.array) -> list[AlignedResult]:
169215 raise NotImplementedError
170216
def decode_chunk(
    self,
    audio_data: mx.array,
    verbose: bool = False,
    *,
    language: str | None = None,
) -> AlignedResult:
    """Transcribe one piece of audio in a single decoding pass.

    The audio is converted to a log-mel spectrogram using
    ``self.preprocessor_config``, passed to ``self.decode`` together with the
    language hint, and the first decoded result is returned.

    Args:
        audio_data: Audio samples to transcribe.
        verbose: When true, print the text of the result.
        language: Optional language hint forwarded to ``decode``.

    Returns:
        The first element of the list returned by ``decode``.
    """
    features = log_mel_spectrogram(audio_data, self.preprocessor_config)
    decoded = self.decode(features, language=language)
    first_result = decoded[0]
    if verbose:
        print(first_result.text)
    return first_result
@@ -186,6 +236,7 @@ def generate(
186236 overlap_duration : float = 15.0 ,
187237 chunk_callback : Optional [Callable ] = None ,
188238 stream : bool = False ,
239+ language : str | None = None ,
189240 ** kwargs ,
190241 ) -> AlignedResult | Generator [StreamingResult , None , None ]:
191242 """
@@ -214,6 +265,7 @@ def generate(
214265 chunk_duration = chunk_duration ,
215266 overlap_duration = overlap_duration ,
216267 verbose = verbose ,
268+ language = language ,
217269 ** kwargs ,
218270 )
219271
@@ -226,13 +278,16 @@ def generate(
226278 # mx.array input
227279 audio_data = audio .astype (dtype ) if audio .dtype != dtype else audio
228280
229- if chunk_duration is None :
230- return self .decode_chunk (audio_data , verbose )
231-
232281 audio_length_seconds = len (audio_data ) / self .preprocessor_config .sample_rate
282+ chunk_duration , overlap_duration = self ._resolve_multilingual_chunking (
283+ audio_length_seconds , chunk_duration , overlap_duration , language
284+ )
285+
286+ if chunk_duration is None :
287+ return self .decode_chunk (audio_data , verbose , language = language )
233288
234289 if audio_length_seconds <= chunk_duration :
235- return self .decode_chunk (audio_data , verbose )
290+ return self .decode_chunk (audio_data , verbose , language = language )
236291
237292 chunk_samples = int (chunk_duration * self .preprocessor_config .sample_rate )
238293 overlap_samples = int (overlap_duration * self .preprocessor_config .sample_rate )
@@ -251,7 +306,7 @@ def generate(
251306 chunk_audio = audio_data [start :end ]
252307 chunk_mel = log_mel_spectrogram (chunk_audio , self .preprocessor_config )
253308
254- chunk_result = self .decode (chunk_mel )[0 ]
309+ chunk_result = self .decode (chunk_mel , language = language )[0 ]
255310
256311 chunk_offset = start / self .preprocessor_config .sample_rate
257312 for sentence in chunk_result .sentences :
@@ -300,6 +355,7 @@ def stream_generate(
300355 chunk_duration : float = 5.0 ,
301356 overlap_duration : float = 1.0 ,
302357 verbose : bool = False ,
358+ language : str | None = None ,
303359 ** kwargs ,
304360 ) -> Generator [StreamingResult , None , None ]:
305361 """
@@ -351,7 +407,7 @@ def stream_generate(
351407 chunk_audio = audio_data [start :end ]
352408 chunk_mel = log_mel_spectrogram (chunk_audio , self .preprocessor_config )
353409
354- chunk_result = self .decode (chunk_mel )[0 ]
410+ chunk_result = self .decode (chunk_mel , language = language )[0 ]
355411
356412 # Adjust timestamps for chunk offset
357413 chunk_offset = start / sample_rate
@@ -508,7 +564,7 @@ def __init__(self, args: ParakeetTDTArgs):
508564 self .decoder = PredictNetwork (args .decoder )
509565 self .joint = JointNetwork (args .joint )
510566
511- def decode (self , mel : mx .array ) -> list [AlignedResult ]:
567+ def decode (self , mel : mx .array , * , language : str | None = None ) -> list [AlignedResult ]:
512568 """
513569 Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
514570 mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
@@ -615,7 +671,7 @@ def __init__(self, args: ParakeetRNNTArgs):
615671 self .decoder = PredictNetwork (args .decoder )
616672 self .joint = JointNetwork (args .joint )
617673
618- def decode (self , mel : mx .array ) -> list [AlignedResult ]:
674+ def decode (self , mel : mx .array , * , language : str | None = None ) -> list [AlignedResult ]:
619675 """
620676 Generate with skip token logic for the Parakeet model, handling batches and single input. Uses greedy decoding.
621677 mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
@@ -709,7 +765,7 @@ def __init__(self, args: ParakeetCTCArgs):
709765 self .encoder = Conformer (args .encoder )
710766 self .decoder = ConvASRDecoder (args .decoder )
711767
712- def decode (self , mel : mx .array ) -> list [AlignedResult ]:
768+ def decode (self , mel : mx .array , * , language : str | None = None ) -> list [AlignedResult ]:
713769 """
714770 Generate with CTC decoding for the Parakeet model, handling batches and single input. Uses greedy decoding.
715771 mel: [batch, sequence, mel_dim] or [sequence, mel_dim]
0 commit comments