Commit 657ce8b

feat(ASR): MLX Whisper Support for Apple Silicon (#2366)
* add mlx-whisper support
* added mlx-whisper example and test. update docling cli to use MLX automatically if present.
* fix pre-commit checks and added proper type safety
* fixed linter issue
* DCO Remediation Commit for Ken Steele <[email protected]>

  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: a979a68
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 9827068
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: ebbeb45
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 2f6fd3c

  Signed-off-by: Ken Steele <[email protected]>

* fix unit tests and code coverage for CI
* DCO Remediation Commit for Ken Steele <[email protected]>

  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 5e61bf1

  Signed-off-by: Ken Steele <[email protected]>

* fix CI example test - mlx_whisper_example.py defaults to tests/data/audio/sample_10s.mp3 if no args specified.

  Signed-off-by: Ken Steele <[email protected]>

* refactor: centralize audio file extensions and MIME types in base_models.py

  - Move audio file extensions from CLI hardcoded set to FormatToExtensions[InputFormat.AUDIO]
  - Add support for additional audio formats: m4a, aac, ogg, flac, mp4, avi, mov
  - Update FormatToMimeType mapping to include MIME types for all audio formats
  - Update CLI auto-detection to use centralized FormatToExtensions mapping
  - Add comprehensive tests for audio file auto-detection and pipeline selection
  - Ensure explicit pipeline choices are not overridden by auto-detection

  Fixes issue where only .mp3 and .wav files were processed as audio despite CLI auto-detection working for all formats. The document converter now properly recognizes all audio formats through MIME type detection.

  Addresses review comments:
  - Centralizes audio extensions in base_models.py as suggested
  - Maintains existing auto-detection behavior while using centralized data
  - Adds proper test coverage for the audio detection functionality

  All examples and tests pass with the new centralized approach. All audio formats (mp3, wav, m4a, aac, ogg, flac, mp4, avi, mov) now work correctly.

  Signed-off-by: Ken Steele <[email protected]>

* feat: address reviewer feedback - improve CLI auto-detection and add explicit model options

  Review feedback addressed:

  1. Fix CLI auto-detection to only switch to ASR pipeline when ALL files are audio
     - Previously switched if ANY file was audio, now requires ALL files to be audio
     - Added warning for mixed file types with guidance to use --pipeline asr
  2. Add explicit WHISPER_X_MLX and WHISPER_X_NATIVE model options
     - Users can now force specific implementations if desired
     - Auto-selecting models (WHISPER_BASE, etc.) still choose best for hardware
     - Added 12 new explicit model options: _MLX and _NATIVE variants for each size

  CLI now supports:
  - Auto-selecting: whisper_tiny, whisper_base, etc. (choose best for hardware)
  - Explicit MLX: whisper_tiny_mlx, whisper_base_mlx, etc. (force MLX)
  - Explicit Native: whisper_tiny_native, whisper_base_native, etc. (force native)

  Addresses reviewer comments from @dolfim-ibm

  Signed-off-by: Ken Steele <[email protected]>

* DCO Remediation Commit for Ken Steele <[email protected]>

  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: c60e72d
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 9480331
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 21905e8
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 96c669d
  I, Ken Steele <[email protected]>, hereby add my Signed-off-by to this commit: 8371c06

  Signed-off-by: Ken Steele <[email protected]>

* test(asr): add coverage for MLX options, pipeline helpers, and VLM prompts

  - tests/test_asr_mlx_whisper.py: verify explicit MLX options (framework, repo ids)
  - tests/test_asr_pipeline.py: cover _has_text/_determine_status and backend support with proper InputDocument/NoOpBackend wiring
  - tests/test_interfaces.py: add BaseVlmPageModel.formulate_prompt tests (RAW/NONE/CHAT, invalid style), with minimal InlineVlmOptions scaffold

  Improves reliability of ASR and VLM components by validating configuration paths and helper logic.

  Signed-off-by: Ken Steele <[email protected]>

* test(asr): broaden coverage for model selection, pipeline flows, and VLM prompts

  - tests/test_asr_mlx_whisper.py
    - Add MLX/native selector coverage across all Whisper sizes
    - Validate repo_id choices under MLX and Native paths
    - Cover fallback path when MPS unavailable and mlx_whisper missing
  - tests/test_asr_pipeline.py
    - Relax silent-audio assertion to accept PARTIAL_SUCCESS or SUCCESS
    - Force CPU native path in helper tests to avoid torch in device selection
    - Add language handling tests for native/MLX transcribe
    - Cover native run success (BytesIO) and failure (exception) branches
    - Cover MLX run success/failure branches with mocked transcribe
    - Add init path coverage with artifacts_path
  - tests/test_interfaces.py
    - Add focused VLM prompt tests (NONE/CHAT variants)

  Result: all tests passing with significantly improved coverage for ASR model selectors, pipeline execution paths, and VLM prompt formulation.

  Signed-off-by: Ken Steele <[email protected]>

* simplify ASR model settings (no pipeline detection needed)

  Signed-off-by: Michele Dolfi <[email protected]>

* clean up disk space in runners

  Signed-off-by: Michele Dolfi <[email protected]>

---------

Signed-off-by: Ken Steele <[email protected]>
Signed-off-by: Michele Dolfi <[email protected]>
Co-authored-by: Michele Dolfi <[email protected]>
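
The auto-detection rule described in the message (switch to the ASR pipeline only when every input is audio, warn on mixed input, never override an explicit choice) can be sketched roughly as below. This is an illustration under stated assumptions, not the code from this commit: `AUDIO_EXTENSIONS`, `is_audio`, `choose_pipeline`, and the string pipeline names are hypothetical stand-ins. Per the message, the real extension list lives in `FormatToExtensions[InputFormat.AUDIO]`.

```python
import logging
from pathlib import Path
from typing import Iterable, Optional

_log = logging.getLogger(__name__)

# Hypothetical stand-in for FormatToExtensions[InputFormat.AUDIO];
# the formats are the ones listed in the commit message.
AUDIO_EXTENSIONS = {"mp3", "wav", "m4a", "aac", "ogg", "flac", "mp4", "avi", "mov"}


def is_audio(path: Path) -> bool:
    """Return True if the file extension is a known audio extension."""
    return path.suffix.lower().lstrip(".") in AUDIO_EXTENSIONS


def choose_pipeline(paths: Iterable[Path], explicit: Optional[str] = None) -> str:
    """Pick a pipeline name; an explicit choice is never overridden."""
    if explicit is not None:
        return explicit
    flags = [is_audio(p) for p in paths]
    if flags and all(flags):
        # Every input is audio, so switching to ASR is safe.
        return "asr"
    if any(flags):
        # Mixed input: keep the default pipeline and tell the user how to force ASR.
        _log.warning("Mixed audio and non-audio inputs; use --pipeline asr to force ASR.")
    return "standard"
```

With this sketch, `choose_pipeline([Path("a.mp3"), Path("b.pdf")])` stays on the default pipeline and logs a warning, while an all-audio list selects ASR.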
1 parent a5af082 commit 657ce8b

29 files changed: 2016 additions & 71 deletions

.github/workflows/checks.yml

Lines changed: 8 additions & 2 deletions
@@ -242,8 +242,14 @@ jobs:
             ~/.EasyOCR/
           key: models-cache
 
-      - name: Pre-download Models
-        run: uv run python -c "import easyocr; reader = easyocr.Reader(['en', 'fr', 'de', 'es'])"
+      - name: Free up disk space
+        run: |
+          df -h
+          sudo rm -rf /usr/share/dotnet
+          sudo rm -rf /usr/local/lib/android
+          sudo rm -rf /opt/ghc
+          sudo apt-get clean
+          df -h
 
       - name: Run examples
         run: |

docling/cli/main.py

Lines changed: 74 additions & 29 deletions
@@ -32,11 +32,23 @@
 from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.asr_model_specs import (
     WHISPER_BASE,
+    WHISPER_BASE_MLX,
+    WHISPER_BASE_NATIVE,
     WHISPER_LARGE,
+    WHISPER_LARGE_MLX,
+    WHISPER_LARGE_NATIVE,
     WHISPER_MEDIUM,
+    WHISPER_MEDIUM_MLX,
+    WHISPER_MEDIUM_NATIVE,
     WHISPER_SMALL,
+    WHISPER_SMALL_MLX,
+    WHISPER_SMALL_NATIVE,
     WHISPER_TINY,
+    WHISPER_TINY_MLX,
+    WHISPER_TINY_NATIVE,
     WHISPER_TURBO,
+    WHISPER_TURBO_MLX,
+    WHISPER_TURBO_NATIVE,
     AsrModelType,
 )
 from docling.datamodel.base_models import (
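
According to the commit message, the unsuffixed names imported above auto-select an implementation for the hardware, while the `_MLX` and `_NATIVE` names force one. A minimal sketch of that selection idea follows. It is not docling's implementation: `mlx_available` and `pick_implementation` are hypothetical names. The platform check is an assumption standing in for the MPS check the message mentions.

```python
import importlib.util
import platform
import sys
from typing import Optional


def mlx_available() -> bool:
    """Heuristic (assumption): Apple Silicon macOS with mlx_whisper importable."""
    on_apple_silicon = sys.platform == "darwin" and platform.machine() == "arm64"
    # find_spec returns None when the top-level package is not installed.
    return on_apple_silicon and importlib.util.find_spec("mlx_whisper") is not None


def pick_implementation(requested: str = "auto", has_mlx: Optional[bool] = None) -> str:
    """Return "mlx" or "native"; explicit requests are honored as given."""
    if requested in ("mlx", "native"):
        return requested
    if requested != "auto":
        raise ValueError(f"unknown implementation: {requested}")
    if has_mlx is None:
        has_mlx = mlx_available()
    # Fall back to the native implementation when MLX cannot be used.
    return "mlx" if has_mlx else "native"
```

The `has_mlx` parameter exists only so the fallback path can be exercised without Apple hardware, much as the commit's tests mock the missing `mlx_whisper` case.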
@@ -611,6 +623,7 @@ def convert( # noqa: C901
             ocr_options.psm = psm
 
         accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
+
         # pipeline_options: PaginatedPipelineOptions
         pipeline_options: PipelineOptions
 
@@ -747,42 +760,74 @@ def convert( # noqa: C901
                 InputFormat.IMAGE: pdf_format_option,
             }
 
-        elif pipeline == ProcessingPipeline.ASR:
-            pipeline_options = AsrPipelineOptions(
-                # enable_remote_services=enable_remote_services,
-                # artifacts_path = artifacts_path
-            )
+        # Set ASR options
+        asr_pipeline_options = AsrPipelineOptions(
+            accelerator_options=AcceleratorOptions(
+                device=device,
+                num_threads=num_threads,
+            ),
+            # enable_remote_services=enable_remote_services,
+            # artifacts_path = artifacts_path
+        )
 
-            if asr_model == AsrModelType.WHISPER_TINY:
-                pipeline_options.asr_options = WHISPER_TINY
-            elif asr_model == AsrModelType.WHISPER_SMALL:
-                pipeline_options.asr_options = WHISPER_SMALL
-            elif asr_model == AsrModelType.WHISPER_MEDIUM:
-                pipeline_options.asr_options = WHISPER_MEDIUM
-            elif asr_model == AsrModelType.WHISPER_BASE:
-                pipeline_options.asr_options = WHISPER_BASE
-            elif asr_model == AsrModelType.WHISPER_LARGE:
-                pipeline_options.asr_options = WHISPER_LARGE
-            elif asr_model == AsrModelType.WHISPER_TURBO:
-                pipeline_options.asr_options = WHISPER_TURBO
-            else:
-                _log.error(f"{asr_model} is not known")
-                raise ValueError(f"{asr_model} is not known")
+        # Auto-selecting models (choose best implementation for hardware)
+        if asr_model == AsrModelType.WHISPER_TINY:
+            asr_pipeline_options.asr_options = WHISPER_TINY
+        elif asr_model == AsrModelType.WHISPER_SMALL:
+            asr_pipeline_options.asr_options = WHISPER_SMALL
+        elif asr_model == AsrModelType.WHISPER_MEDIUM:
+            asr_pipeline_options.asr_options = WHISPER_MEDIUM
+        elif asr_model == AsrModelType.WHISPER_BASE:
+            asr_pipeline_options.asr_options = WHISPER_BASE
+        elif asr_model == AsrModelType.WHISPER_LARGE:
+            asr_pipeline_options.asr_options = WHISPER_LARGE
+        elif asr_model == AsrModelType.WHISPER_TURBO:
+            asr_pipeline_options.asr_options = WHISPER_TURBO
+
+        # Explicit MLX models (force MLX implementation)
+        elif asr_model == AsrModelType.WHISPER_TINY_MLX:
+            asr_pipeline_options.asr_options = WHISPER_TINY_MLX
+        elif asr_model == AsrModelType.WHISPER_SMALL_MLX:
+            asr_pipeline_options.asr_options = WHISPER_SMALL_MLX
+        elif asr_model == AsrModelType.WHISPER_MEDIUM_MLX:
+            asr_pipeline_options.asr_options = WHISPER_MEDIUM_MLX
+        elif asr_model == AsrModelType.WHISPER_BASE_MLX:
+            asr_pipeline_options.asr_options = WHISPER_BASE_MLX
+        elif asr_model == AsrModelType.WHISPER_LARGE_MLX:
+            asr_pipeline_options.asr_options = WHISPER_LARGE_MLX
+        elif asr_model == AsrModelType.WHISPER_TURBO_MLX:
+            asr_pipeline_options.asr_options = WHISPER_TURBO_MLX
+
+        # Explicit Native models (force native implementation)
+        elif asr_model == AsrModelType.WHISPER_TINY_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_TINY_NATIVE
+        elif asr_model == AsrModelType.WHISPER_SMALL_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_SMALL_NATIVE
+        elif asr_model == AsrModelType.WHISPER_MEDIUM_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_MEDIUM_NATIVE
+        elif asr_model == AsrModelType.WHISPER_BASE_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_BASE_NATIVE
+        elif asr_model == AsrModelType.WHISPER_LARGE_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_LARGE_NATIVE
+        elif asr_model == AsrModelType.WHISPER_TURBO_NATIVE:
+            asr_pipeline_options.asr_options = WHISPER_TURBO_NATIVE
 
-            _log.info(f"pipeline_options: {pipeline_options}")
+        else:
+            _log.error(f"{asr_model} is not known")
+            raise ValueError(f"{asr_model} is not known")
 
-            audio_format_option = AudioFormatOption(
-                pipeline_cls=AsrPipeline,
-                pipeline_options=pipeline_options,
-            )
+        _log.info(f"ASR pipeline_options: {asr_pipeline_options}")
 
-            format_options = {
-                InputFormat.AUDIO: audio_format_option,
-            }
+        audio_format_option = AudioFormatOption(
+            pipeline_cls=AsrPipeline,
+            pipeline_options=asr_pipeline_options,
+        )
+        format_options[InputFormat.AUDIO] = audio_format_option
 
+        # Common options for all pipelines
         if artifacts_path is not None:
             pipeline_options.artifacts_path = artifacts_path
-            # audio_pipeline_options.artifacts_path = artifacts_path
+            asr_pipeline_options.artifacts_path = artifacts_path
 
         doc_converter = DocumentConverter(
             allowed_formats=from_formats,
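
The `if`/`elif` chain in the hunk above pairs each `AsrModelType` member with the spec constant of the same name and raises `ValueError` for anything else. The same mapping can be read as a lookup table, sketched below for illustration only. This is not part of the commit. The enum subset, the placeholder spec dictionaries, `ASR_SPECS`, and `resolve_asr_options` are all hypothetical stand-ins for the real docling objects.

```python
from enum import Enum


class AsrModelType(str, Enum):
    # Stand-in subset; values follow the CLI names in the commit message.
    WHISPER_TINY = "whisper_tiny"
    WHISPER_TINY_MLX = "whisper_tiny_mlx"
    WHISPER_TINY_NATIVE = "whisper_tiny_native"


# Placeholder spec objects (the real ones come from docling.datamodel.asr_model_specs).
WHISPER_TINY = {"size": "tiny", "impl": "auto"}
WHISPER_TINY_MLX = {"size": "tiny", "impl": "mlx"}
WHISPER_TINY_NATIVE = {"size": "tiny", "impl": "native"}

# One table entry per branch of the chain.
ASR_SPECS = {
    AsrModelType.WHISPER_TINY: WHISPER_TINY,
    AsrModelType.WHISPER_TINY_MLX: WHISPER_TINY_MLX,
    AsrModelType.WHISPER_TINY_NATIVE: WHISPER_TINY_NATIVE,
}


def resolve_asr_options(asr_model):
    """Return the spec for a model type, mirroring the chain's else branch on a miss."""
    try:
        return ASR_SPECS[asr_model]
    except KeyError:
        raise ValueError(f"{asr_model} is not known") from None
```

For example, `resolve_asr_options(AsrModelType.WHISPER_TINY_MLX)` returns the MLX placeholder spec, and an unrecognized value raises `ValueError` just as the final `else` branch does.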
