|
32 | 32 | from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions |
33 | 33 | from docling.datamodel.asr_model_specs import ( |
34 | 34 | WHISPER_BASE, |
| 35 | + WHISPER_BASE_MLX, |
| 36 | + WHISPER_BASE_NATIVE, |
35 | 37 | WHISPER_LARGE, |
| 38 | + WHISPER_LARGE_MLX, |
| 39 | + WHISPER_LARGE_NATIVE, |
36 | 40 | WHISPER_MEDIUM, |
| 41 | + WHISPER_MEDIUM_MLX, |
| 42 | + WHISPER_MEDIUM_NATIVE, |
37 | 43 | WHISPER_SMALL, |
| 44 | + WHISPER_SMALL_MLX, |
| 45 | + WHISPER_SMALL_NATIVE, |
38 | 46 | WHISPER_TINY, |
| 47 | + WHISPER_TINY_MLX, |
| 48 | + WHISPER_TINY_NATIVE, |
39 | 49 | WHISPER_TURBO, |
| 50 | + WHISPER_TURBO_MLX, |
| 51 | + WHISPER_TURBO_NATIVE, |
40 | 52 | AsrModelType, |
41 | 53 | ) |
42 | 54 | from docling.datamodel.base_models import ( |
@@ -611,6 +623,7 @@ def convert( # noqa: C901 |
611 | 623 | ocr_options.psm = psm |
612 | 624 |
|
613 | 625 | accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device) |
| 626 | + |
614 | 627 | # pipeline_options: PaginatedPipelineOptions |
615 | 628 | pipeline_options: PipelineOptions |
616 | 629 |
|
@@ -747,42 +760,74 @@ def convert( # noqa: C901 |
747 | 760 | InputFormat.IMAGE: pdf_format_option, |
748 | 761 | } |
749 | 762 |
|
750 | | - elif pipeline == ProcessingPipeline.ASR: |
751 | | - pipeline_options = AsrPipelineOptions( |
752 | | - # enable_remote_services=enable_remote_services, |
753 | | - # artifacts_path = artifacts_path |
754 | | - ) |
| 763 | + # Set ASR options |
| 764 | + asr_pipeline_options = AsrPipelineOptions( |
| 765 | + accelerator_options=AcceleratorOptions( |
| 766 | + device=device, |
| 767 | + num_threads=num_threads, |
| 768 | + ), |
| 769 | + # enable_remote_services=enable_remote_services, |
| 770 | + # artifacts_path = artifacts_path |
| 771 | + ) |
755 | 772 |
|
756 | | - if asr_model == AsrModelType.WHISPER_TINY: |
757 | | - pipeline_options.asr_options = WHISPER_TINY |
758 | | - elif asr_model == AsrModelType.WHISPER_SMALL: |
759 | | - pipeline_options.asr_options = WHISPER_SMALL |
760 | | - elif asr_model == AsrModelType.WHISPER_MEDIUM: |
761 | | - pipeline_options.asr_options = WHISPER_MEDIUM |
762 | | - elif asr_model == AsrModelType.WHISPER_BASE: |
763 | | - pipeline_options.asr_options = WHISPER_BASE |
764 | | - elif asr_model == AsrModelType.WHISPER_LARGE: |
765 | | - pipeline_options.asr_options = WHISPER_LARGE |
766 | | - elif asr_model == AsrModelType.WHISPER_TURBO: |
767 | | - pipeline_options.asr_options = WHISPER_TURBO |
768 | | - else: |
769 | | - _log.error(f"{asr_model} is not known") |
770 | | - raise ValueError(f"{asr_model} is not known") |
| 773 | + # Auto-selecting models (choose best implementation for hardware) |
| 774 | + if asr_model == AsrModelType.WHISPER_TINY: |
| 775 | + asr_pipeline_options.asr_options = WHISPER_TINY |
| 776 | + elif asr_model == AsrModelType.WHISPER_SMALL: |
| 777 | + asr_pipeline_options.asr_options = WHISPER_SMALL |
| 778 | + elif asr_model == AsrModelType.WHISPER_MEDIUM: |
| 779 | + asr_pipeline_options.asr_options = WHISPER_MEDIUM |
| 780 | + elif asr_model == AsrModelType.WHISPER_BASE: |
| 781 | + asr_pipeline_options.asr_options = WHISPER_BASE |
| 782 | + elif asr_model == AsrModelType.WHISPER_LARGE: |
| 783 | + asr_pipeline_options.asr_options = WHISPER_LARGE |
| 784 | + elif asr_model == AsrModelType.WHISPER_TURBO: |
| 785 | + asr_pipeline_options.asr_options = WHISPER_TURBO |
| 786 | + |
| 787 | + # Explicit MLX models (force MLX implementation) |
| 788 | + elif asr_model == AsrModelType.WHISPER_TINY_MLX: |
| 789 | + asr_pipeline_options.asr_options = WHISPER_TINY_MLX |
| 790 | + elif asr_model == AsrModelType.WHISPER_SMALL_MLX: |
| 791 | + asr_pipeline_options.asr_options = WHISPER_SMALL_MLX |
| 792 | + elif asr_model == AsrModelType.WHISPER_MEDIUM_MLX: |
| 793 | + asr_pipeline_options.asr_options = WHISPER_MEDIUM_MLX |
| 794 | + elif asr_model == AsrModelType.WHISPER_BASE_MLX: |
| 795 | + asr_pipeline_options.asr_options = WHISPER_BASE_MLX |
| 796 | + elif asr_model == AsrModelType.WHISPER_LARGE_MLX: |
| 797 | + asr_pipeline_options.asr_options = WHISPER_LARGE_MLX |
| 798 | + elif asr_model == AsrModelType.WHISPER_TURBO_MLX: |
| 799 | + asr_pipeline_options.asr_options = WHISPER_TURBO_MLX |
| 800 | + |
| 801 | + # Explicit Native models (force native implementation) |
| 802 | + elif asr_model == AsrModelType.WHISPER_TINY_NATIVE: |
| 803 | + asr_pipeline_options.asr_options = WHISPER_TINY_NATIVE |
| 804 | + elif asr_model == AsrModelType.WHISPER_SMALL_NATIVE: |
| 805 | + asr_pipeline_options.asr_options = WHISPER_SMALL_NATIVE |
| 806 | + elif asr_model == AsrModelType.WHISPER_MEDIUM_NATIVE: |
| 807 | + asr_pipeline_options.asr_options = WHISPER_MEDIUM_NATIVE |
| 808 | + elif asr_model == AsrModelType.WHISPER_BASE_NATIVE: |
| 809 | + asr_pipeline_options.asr_options = WHISPER_BASE_NATIVE |
| 810 | + elif asr_model == AsrModelType.WHISPER_LARGE_NATIVE: |
| 811 | + asr_pipeline_options.asr_options = WHISPER_LARGE_NATIVE |
| 812 | + elif asr_model == AsrModelType.WHISPER_TURBO_NATIVE: |
| 813 | + asr_pipeline_options.asr_options = WHISPER_TURBO_NATIVE |
771 | 814 |
|
772 | | - _log.info(f"pipeline_options: {pipeline_options}") |
| 815 | + else: |
| 816 | + _log.error(f"{asr_model} is not known") |
| 817 | + raise ValueError(f"{asr_model} is not known") |
773 | 818 |
|
774 | | - audio_format_option = AudioFormatOption( |
775 | | - pipeline_cls=AsrPipeline, |
776 | | - pipeline_options=pipeline_options, |
777 | | - ) |
| 819 | + _log.info(f"ASR pipeline_options: {asr_pipeline_options}") |
778 | 820 |
|
779 | | - format_options = { |
780 | | - InputFormat.AUDIO: audio_format_option, |
781 | | - } |
| 821 | + audio_format_option = AudioFormatOption( |
| 822 | + pipeline_cls=AsrPipeline, |
| 823 | + pipeline_options=asr_pipeline_options, |
| 824 | + ) |
| 825 | + format_options[InputFormat.AUDIO] = audio_format_option |
782 | 826 |
|
| 827 | + # Common options for all pipelines |
783 | 828 | if artifacts_path is not None: |
784 | 829 | pipeline_options.artifacts_path = artifacts_path |
785 | | - # audio_pipeline_options.artifacts_path = artifacts_path |
| 830 | + asr_pipeline_options.artifacts_path = artifacts_path |
786 | 831 |
|
787 | 832 | doc_converter = DocumentConverter( |
788 | 833 | allowed_formats=from_formats, |
|
0 commit comments