Skip to content

Commit 5fac1de

Browse files
authored
[Whisper] Fix lang code assignment (#549)
* Enhance language and task handling in the Whisper model: introduce optional `language` and `task` parameters in the audio processing methods, improving flexibility and user control over transcription behavior.
* Format.
* Update version to 0.4.0 in version.py.
1 parent 3da730e commit 5fac1de

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

mlx_audio/stt/models/whisper/whisper.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,8 @@ def generate(
770770
audio: Union[str, np.ndarray, mx.array],
771771
*,
772772
verbose: Optional[bool] = None,
773+
language: Optional[str] = None,
774+
task: str = "transcribe",
773775
chunk_duration: float = 1.0,
774776
stream: bool = False,
775777
generation_stream: bool = False,
@@ -851,7 +853,10 @@ def generate(
851853

852854
if stream:
853855
return self.generate_streaming(
854-
audio, chunk_duration=chunk_duration, **decode_options
856+
audio,
857+
chunk_duration=chunk_duration,
858+
language=language,
859+
task=task,
855860
)
856861

857862
decode_options.pop("max_tokens", None)
@@ -870,12 +875,12 @@ def generate(
870875
make_safe = lambda x: x
871876

872877
# Use shared language detection
873-
language = self._detect_language(mel, language=decode_options.get("language"))
874-
if decode_options.get("language") is None:
875-
if verbose:
876-
print(f"Detected language: {LANGUAGES[language].title()}")
878+
detected = language is None
879+
language = self._detect_language(mel, language=language)
880+
if detected and verbose:
881+
print(f"Detected language: {LANGUAGES[language].title()}")
877882
decode_options["language"] = language
878-
task: str = decode_options.get("task", "transcribe")
883+
decode_options["task"] = task
879884
tokenizer = self.get_tokenizer(language=language, task=task)
880885

881886
if isinstance(clip_timestamps, str):

mlx_audio/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.2"
1+
__version__ = "0.4.0"

0 commit comments

Comments
 (0)