Skip to content

Commit 2c91234

Browse files
authored
feat: enrichment steps on all convert pipelines (incl docx, html, etc) (#2251)
* allow enrichment on all convert pipelines Signed-off-by: Michele Dolfi <[email protected]> * set options in CLI Signed-off-by: Michele Dolfi <[email protected]> --------- Signed-off-by: Michele Dolfi <[email protected]>
1 parent c696549 commit 2c91234

File tree

12 files changed

+234
-189
lines changed

12 files changed

+234
-189
lines changed

docling/cli/main.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from docling.datamodel.document import ConversionResult
4949
from docling.datamodel.pipeline_options import (
5050
AsrPipelineOptions,
51+
ConvertPipelineOptions,
5152
EasyOcrOptions,
5253
OcrOptions,
5354
PaginatedPipelineOptions,
@@ -71,8 +72,13 @@
7172
from docling.document_converter import (
7273
AudioFormatOption,
7374
DocumentConverter,
75+
ExcelFormatOption,
7476
FormatOption,
77+
HTMLFormatOption,
78+
MarkdownFormatOption,
7579
PdfFormatOption,
80+
PowerpointFormatOption,
81+
WordFormatOption,
7682
)
7783
from docling.models.factories import get_ocr_factory
7884
from docling.pipeline.asr_pipeline import AsrPipeline
@@ -626,10 +632,33 @@ def convert( # noqa: C901
626632
backend=MetsGbsDocumentBackend,
627633
)
628634

635+
# SimplePipeline options
636+
simple_format_option = ConvertPipelineOptions(
637+
do_picture_description=enrich_picture_description,
638+
do_picture_classification=enrich_picture_classes,
639+
)
640+
if artifacts_path is not None:
641+
simple_format_option.artifacts_path = artifacts_path
642+
629643
format_options = {
630644
InputFormat.PDF: pdf_format_option,
631645
InputFormat.IMAGE: pdf_format_option,
632646
InputFormat.METS_GBS: mets_gbs_format_option,
647+
InputFormat.DOCX: WordFormatOption(
648+
pipeline_options=simple_format_option
649+
),
650+
InputFormat.PPTX: PowerpointFormatOption(
651+
pipeline_options=simple_format_option
652+
),
653+
InputFormat.XLSX: ExcelFormatOption(
654+
pipeline_options=simple_format_option
655+
),
656+
InputFormat.HTML: HTMLFormatOption(
657+
pipeline_options=simple_format_option
658+
),
659+
InputFormat.MD: MarkdownFormatOption(
660+
pipeline_options=simple_format_option
661+
),
633662
}
634663

635664
elif pipeline == ProcessingPipeline.VLM:

docling/datamodel/pipeline_options.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,21 @@ class PipelineOptions(BaseOptions):
259259
accelerator_options: AcceleratorOptions = AcceleratorOptions()
260260
enable_remote_services: bool = False
261261
allow_external_plugins: bool = False
262+
artifacts_path: Optional[Union[Path, str]] = None
262263

263264

264-
class PaginatedPipelineOptions(PipelineOptions):
265-
artifacts_path: Optional[Union[Path, str]] = None
265+
class ConvertPipelineOptions(PipelineOptions):
266+
"""Base convert pipeline options."""
267+
268+
do_picture_classification: bool = False # True: classify pictures in documents
269+
270+
do_picture_description: bool = False # True: run describe pictures in documents
271+
picture_description_options: PictureDescriptionBaseOptions = (
272+
smolvlm_picture_description
273+
)
266274

275+
276+
class PaginatedPipelineOptions(ConvertPipelineOptions):
267277
images_scale: float = 1.0
268278
generate_page_images: bool = False
269279
generate_picture_images: bool = False
@@ -295,13 +305,11 @@ class LayoutOptions(BaseModel):
295305

296306
class AsrPipelineOptions(PipelineOptions):
297307
asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY
298-
artifacts_path: Optional[Union[Path, str]] = None
299308

300309

301310
class VlmExtractionPipelineOptions(PipelineOptions):
302311
"""Options for extraction pipeline."""
303312

304-
artifacts_path: Optional[Union[Path, str]] = None
305313
vlm_options: Union[InlineVlmOptions] = NU_EXTRACT_2B_TRANSFORMERS
306314

307315

@@ -312,18 +320,13 @@ class PdfPipelineOptions(PaginatedPipelineOptions):
312320
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
313321
do_code_enrichment: bool = False # True: perform code OCR
314322
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
315-
do_picture_classification: bool = False # True: classify pictures in documents
316-
do_picture_description: bool = False # True: run describe pictures in documents
317323
force_backend_text: bool = (
318324
False # (To be used with vlms, or other generative models)
319325
)
320326
# If True, text from backend will be used instead of generated text
321327

322328
table_structure_options: TableStructureOptions = TableStructureOptions()
323329
ocr_options: OcrOptions = EasyOcrOptions()
324-
picture_description_options: PictureDescriptionBaseOptions = (
325-
smolvlm_picture_description
326-
)
327330
layout_options: LayoutOptions = LayoutOptions()
328331

329332
images_scale: float = 1.0

docling/models/base_model.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from typing import Any, Generic, Optional, Protocol, Type, Union
55

66
import numpy as np
7-
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
7+
from docling_core.types.doc import (
8+
BoundingBox,
9+
DocItem,
10+
DoclingDocument,
11+
NodeItem,
12+
PictureItem,
13+
)
814
from PIL.Image import Image
915
from typing_extensions import TypeVar
1016

@@ -164,8 +170,17 @@ def prepare_element(
164170
return None
165171

166172
assert isinstance(element, DocItem)
167-
element_prov = element.prov[0]
168173

174+
# Allow the case of documents without page images but embedded images (e.g. Word and HTML docs)
175+
if len(element.prov) == 0 and isinstance(element, PictureItem):
176+
embedded_im = element.get_image(conv_res.document)
177+
if embedded_im is not None:
178+
return ItemAndImageEnrichmentElement(item=element, image=embedded_im)
179+
else:
180+
return None
181+
182+
# Crop the image form the page
183+
element_prov = element.prov[0]
169184
bbox = element_prov.bbox
170185
width = bbox.r - bbox.l
171186
height = bbox.t - bbox.b
@@ -183,4 +198,14 @@ def prepare_element(
183198
cropped_image = conv_res.pages[page_ix].get_image(
184199
scale=self.images_scale, cropbox=expanded_bbox
185200
)
201+
202+
# Allow for images being embedded without the page backend or page images
203+
if cropped_image is None and isinstance(element, PictureItem):
204+
embedded_im = element.get_image(conv_res.document)
205+
if embedded_im is not None:
206+
return ItemAndImageEnrichmentElement(item=element, image=embedded_im)
207+
else:
208+
return None
209+
210+
# Return the proper cropped image
186211
return ItemAndImageEnrichmentElement(item=element, image=cropped_image)

docling/pipeline/asr_pipeline.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -208,25 +208,13 @@ def __init__(self, pipeline_options: AsrPipelineOptions):
208208

209209
self.pipeline_options: AsrPipelineOptions = pipeline_options
210210

211-
artifacts_path: Optional[Path] = None
212-
if pipeline_options.artifacts_path is not None:
213-
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
214-
elif settings.artifacts_path is not None:
215-
artifacts_path = Path(settings.artifacts_path).expanduser()
216-
217-
if artifacts_path is not None and not artifacts_path.is_dir():
218-
raise RuntimeError(
219-
f"The value of {artifacts_path=} is not valid. "
220-
"When defined, it must point to a folder containing all models required by the pipeline."
221-
)
222-
223211
if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
224212
asr_options: InlineAsrNativeWhisperOptions = (
225213
self.pipeline_options.asr_options
226214
)
227215
self._model = _NativeWhisperModel(
228216
enabled=True, # must be always enabled for this pipeline to make sense.
229-
artifacts_path=artifacts_path,
217+
artifacts_path=self.artifacts_path,
230218
accelerator_options=pipeline_options.accelerator_options,
231219
asr_options=asr_options,
232220
)

docling/pipeline/base_extraction_pipeline.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
11
import logging
22
from abc import ABC, abstractmethod
3+
from pathlib import Path
34
from typing import Optional
45

56
from docling.datamodel.base_models import ConversionStatus, ErrorItem
67
from docling.datamodel.document import InputDocument
78
from docling.datamodel.extraction import ExtractionResult, ExtractionTemplateType
8-
from docling.datamodel.pipeline_options import BaseOptions
9+
from docling.datamodel.pipeline_options import BaseOptions, PipelineOptions
10+
from docling.datamodel.settings import settings
911

1012
_log = logging.getLogger(__name__)
1113

1214

1315
class BaseExtractionPipeline(ABC):
14-
def __init__(self, pipeline_options: BaseOptions):
16+
def __init__(self, pipeline_options: PipelineOptions):
1517
self.pipeline_options = pipeline_options
1618

19+
self.artifacts_path: Optional[Path] = None
20+
if pipeline_options.artifacts_path is not None:
21+
self.artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
22+
elif settings.artifacts_path is not None:
23+
self.artifacts_path = Path(settings.artifacts_path).expanduser()
24+
25+
if self.artifacts_path is not None and not self.artifacts_path.is_dir():
26+
raise RuntimeError(
27+
f"The value of {self.artifacts_path=} is not valid. "
28+
"When defined, it must point to a folder containing all models required by the pipeline."
29+
)
30+
1731
def execute(
1832
self,
1933
in_doc: InputDocument,
@@ -54,5 +68,5 @@ def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
5468

5569
@classmethod
5670
@abstractmethod
57-
def get_default_options(cls) -> BaseOptions:
71+
def get_default_options(cls) -> PipelineOptions:
5872
pass

docling/pipeline/base_pipeline.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import traceback
55
from abc import ABC, abstractmethod
66
from collections.abc import Iterable
7-
from typing import Any, Callable, List
7+
from pathlib import Path
8+
from typing import Any, Callable, List, Optional
89

910
from docling_core.types.doc import NodeItem
1011

@@ -20,9 +21,19 @@
2021
Page,
2122
)
2223
from docling.datamodel.document import ConversionResult, InputDocument
23-
from docling.datamodel.pipeline_options import PdfPipelineOptions, PipelineOptions
24+
from docling.datamodel.pipeline_options import (
25+
ConvertPipelineOptions,
26+
PdfPipelineOptions,
27+
PipelineOptions,
28+
)
2429
from docling.datamodel.settings import settings
2530
from docling.models.base_model import GenericEnrichmentModel
31+
from docling.models.document_picture_classifier import (
32+
DocumentPictureClassifier,
33+
DocumentPictureClassifierOptions,
34+
)
35+
from docling.models.factories import get_picture_description_factory
36+
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
2637
from docling.utils.profiling import ProfilingScope, TimeRecorder
2738
from docling.utils.utils import chunkify
2839

@@ -36,6 +47,18 @@ def __init__(self, pipeline_options: PipelineOptions):
3647
self.build_pipe: List[Callable] = []
3748
self.enrichment_pipe: List[GenericEnrichmentModel[Any]] = []
3849

50+
self.artifacts_path: Optional[Path] = None
51+
if pipeline_options.artifacts_path is not None:
52+
self.artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
53+
elif settings.artifacts_path is not None:
54+
self.artifacts_path = Path(settings.artifacts_path).expanduser()
55+
56+
if self.artifacts_path is not None and not self.artifacts_path.is_dir():
57+
raise RuntimeError(
58+
f"The value of {self.artifacts_path=} is not valid. "
59+
"When defined, it must point to a folder containing all models required by the pipeline."
60+
)
61+
3962
def execute(self, in_doc: InputDocument, raises_on_error: bool) -> ConversionResult:
4063
conv_res = ConversionResult(input=in_doc)
4164

@@ -108,15 +131,58 @@ def get_default_options(cls) -> PipelineOptions:
108131
def is_backend_supported(cls, backend: AbstractDocumentBackend):
109132
pass
110133

111-
# def _apply_on_elements(self, element_batch: Iterable[NodeItem]) -> Iterable[Any]:
112-
# for model in self.build_pipe:
113-
# element_batch = model(element_batch)
114-
#
115-
# yield from element_batch
116134

135+
class ConvertPipeline(BasePipeline):
136+
def __init__(self, pipeline_options: ConvertPipelineOptions):
137+
super().__init__(pipeline_options)
138+
self.pipeline_options: ConvertPipelineOptions
117139

118-
class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
119-
def __init__(self, pipeline_options: PipelineOptions):
140+
# ------ Common enrichment models working on all backends
141+
142+
# Picture description model
143+
if (
144+
picture_description_model := self._get_picture_description_model(
145+
artifacts_path=self.artifacts_path
146+
)
147+
) is None:
148+
raise RuntimeError(
149+
f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
150+
)
151+
152+
self.enrichment_pipe = [
153+
# Document Picture Classifier
154+
DocumentPictureClassifier(
155+
enabled=pipeline_options.do_picture_classification,
156+
artifacts_path=self.artifacts_path,
157+
options=DocumentPictureClassifierOptions(),
158+
accelerator_options=pipeline_options.accelerator_options,
159+
),
160+
# Document Picture description
161+
picture_description_model,
162+
]
163+
164+
def _get_picture_description_model(
165+
self, artifacts_path: Optional[Path] = None
166+
) -> Optional[PictureDescriptionBaseModel]:
167+
factory = get_picture_description_factory(
168+
allow_external_plugins=self.pipeline_options.allow_external_plugins
169+
)
170+
return factory.create_instance(
171+
options=self.pipeline_options.picture_description_options,
172+
enabled=self.pipeline_options.do_picture_description,
173+
enable_remote_services=self.pipeline_options.enable_remote_services,
174+
artifacts_path=artifacts_path,
175+
accelerator_options=self.pipeline_options.accelerator_options,
176+
)
177+
178+
@classmethod
179+
@abstractmethod
180+
def get_default_options(cls) -> ConvertPipelineOptions:
181+
pass
182+
183+
184+
class PaginatedPipeline(ConvertPipeline): # TODO this is a bad name.
185+
def __init__(self, pipeline_options: ConvertPipelineOptions):
120186
super().__init__(pipeline_options)
121187
self.keep_backend = False
122188

docling/pipeline/extraction_vlm_pipeline.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import inspect
22
import json
33
import logging
4-
from pathlib import Path
54
from typing import Optional
65

76
from PIL.Image import Image
@@ -16,7 +15,10 @@
1615
ExtractionResult,
1716
ExtractionTemplateType,
1817
)
19-
from docling.datamodel.pipeline_options import BaseOptions, VlmExtractionPipelineOptions
18+
from docling.datamodel.pipeline_options import (
19+
PipelineOptions,
20+
VlmExtractionPipelineOptions,
21+
)
2022
from docling.datamodel.settings import settings
2123
from docling.models.vlm_models_inline.nuextract_transformers_model import (
2224
NuExtractTransformersModel,
@@ -35,22 +37,10 @@ def __init__(self, pipeline_options: VlmExtractionPipelineOptions):
3537
self.accelerator_options = pipeline_options.accelerator_options
3638
self.pipeline_options: VlmExtractionPipelineOptions
3739

38-
artifacts_path: Optional[Path] = None
39-
if pipeline_options.artifacts_path is not None:
40-
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
41-
elif settings.artifacts_path is not None:
42-
artifacts_path = Path(settings.artifacts_path).expanduser()
43-
44-
if artifacts_path is not None and not artifacts_path.is_dir():
45-
raise RuntimeError(
46-
f"The value of {artifacts_path=} is not valid. "
47-
"When defined, it must point to a folder containing all models required by the pipeline."
48-
)
49-
5040
# Create VLM model instance
5141
self.vlm_model = NuExtractTransformersModel(
5242
enabled=True,
53-
artifacts_path=artifacts_path, # Will download automatically
43+
artifacts_path=self.artifacts_path, # Will download automatically
5444
accelerator_options=self.accelerator_options,
5545
vlm_options=pipeline_options.vlm_options,
5646
)
@@ -203,5 +193,5 @@ class ExtractionTemplateFactory(ModelFactory[template]): # type: ignore
203193
raise ValueError(f"Unsupported template type: {type(template)}")
204194

205195
@classmethod
206-
def get_default_options(cls) -> BaseOptions:
196+
def get_default_options(cls) -> PipelineOptions:
207197
return VlmExtractionPipelineOptions()

0 commit comments

Comments
 (0)