Skip to content

Commit f9d67fe

Browse files
author
Maksym Lysak
committed
Added support of elements with multiple provenances
Signed-off-by: Maksym Lysak <[email protected]>
1 parent 48cedcd commit f9d67fe

File tree

1 file changed

+163
-64
lines changed

1 file changed

+163
-64
lines changed

docs/examples/post_process_ocr_with_vlm.py

Lines changed: 163 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,56 +3,73 @@
33
from collections.abc import Iterable
44
from concurrent.futures import ThreadPoolExecutor
55
from pathlib import Path
6-
from typing import Optional, Union, Any
6+
from typing import Any, Optional, Union
77

8-
from PIL import Image
9-
10-
from PIL.ImageOps import crop
118
from docling_core.types.doc import (
129
DoclingDocument,
10+
ImageRefMode,
1311
NodeItem,
1412
PageItem,
1513
TextItem,
1614
)
17-
from docling_core.types.doc import ImageRefMode
18-
from docling.utils.profiling import ProfilingScope, TimeRecorder
19-
from docling.utils.utils import chunkify
20-
from docling_core.types.doc.document import ContentLayer, DocItem, GraphCell, KeyValueItem, PictureItem, TableCell, RichTableCell, TableItem
15+
from docling_core.types.doc.document import (
16+
ContentLayer,
17+
DocItem,
18+
GraphCell,
19+
KeyValueItem,
20+
PictureItem,
21+
RichTableCell,
22+
TableCell,
23+
TableItem,
24+
)
25+
from PIL import Image
26+
from PIL.ImageOps import crop
2127
from pydantic import BaseModel, ConfigDict
28+
2229
from docling.backend.json.docling_json_backend import DoclingJSONBackend
2330
from docling.datamodel.accelerator_options import AcceleratorOptions
2431
from docling.datamodel.base_models import InputFormat, ItemAndImageEnrichmentElement
2532
from docling.datamodel.document import ConversionResult
26-
from docling.datamodel.pipeline_options import ConvertPipelineOptions, PictureDescriptionApiOptions, PdfPipelineOptions
33+
from docling.datamodel.pipeline_options import (
34+
ConvertPipelineOptions,
35+
PdfPipelineOptions,
36+
PictureDescriptionApiOptions,
37+
)
2738
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
28-
from docling.pipeline.simple_pipeline import SimplePipeline
29-
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
3039
from docling.exceptions import OperationNotAllowed
3140
from docling.models.base_model import BaseModelWithOptions, GenericEnrichmentModel
41+
from docling.pipeline.simple_pipeline import SimplePipeline
42+
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
3243
from docling.utils.api_image_request import api_image_request
44+
from docling.utils.profiling import ProfilingScope, TimeRecorder
45+
from docling.utils.utils import chunkify
3346

3447
# Example of how to apply OCR to a Docling document as a post-processing step, using "nanonets-ocr2-3b" via LM Studio
3548
# Requires an LM Studio inference server running with the "nanonets-ocr2-3b" model pre-loaded
3649
LM_STUDIO_URL = "http://localhost:1234/v1/chat/completions"
3750
LM_STUDIO_MODEL = "nanonets-ocr2-3b"
38-
DEFAULT_PROMPT = "Extract the text from the above document as if you were reading it naturally."
51+
DEFAULT_PROMPT = (
52+
"Extract the text from the above document as if you were reading it naturally."
53+
)
3954

4055
PDF_DOC = "tests/data/pdf/2305.03393v1-pg9.pdf"
41-
# PDF_DOC = "tests/data/pdf/amt_handbook_sample.pdf"
42-
# PDF_DOC = "tests/data/pdf/2206.01062.pdf"
4356
JSON_DOC = "scratch/test_doc.json"
4457
POST_PROCESSED_JSON_DOC = "scratch/test_doc_ocr.json"
4558

59+
4660
class OcrEnrichmentElement(BaseModel):
4761
model_config = ConfigDict(arbitrary_types_allowed=True)
4862

4963
item: Union[DocItem, TableCell, RichTableCell, GraphCell]
50-
image: Image # TODO maybe needs to be an array of images for multi-provenance things.
64+
image: list[
65+
Image.Image
66+
] # TODO maybe needs to be an array of images for multi-provenance things.
5167

5268

5369
class OcrEnrichmentPipelineOptions(ConvertPipelineOptions):
5470
api_options: PictureDescriptionApiOptions
5571

72+
5673
class OcrEnrichmentPipeline(SimplePipeline):
5774
def __init__(self, pipeline_options: OcrEnrichmentPipelineOptions):
5875
super().__init__(pipeline_options)
@@ -82,9 +99,11 @@ def _prepare_elements(
8299
ContentLayer.BODY,
83100
ContentLayer.FURNITURE,
84101
},
85-
): # With all content layers, with traverse_pictures=True
86-
prepared_elements = model.prepare_element( # make this one yield multiple items.
87-
conv_res=conv_res, element=doc_element
102+
): # With all content layers, with traverse_pictures=True
103+
prepared_elements = (
104+
model.prepare_element( # make this one yield multiple items.
105+
conv_res=conv_res, element=doc_element
106+
)
88107
)
89108
if prepared_elements is not None:
90109
yield prepared_elements
@@ -101,7 +120,10 @@ def _prepare_elements(
101120
pass
102121
return conv_res
103122

104-
class OcrApiEnrichmentModel(GenericEnrichmentModel[OcrEnrichmentElement], BaseModelWithOptions):
123+
124+
class OcrApiEnrichmentModel(
125+
GenericEnrichmentModel[OcrEnrichmentElement], BaseModelWithOptions
126+
):
105127
expansion_factor: float = 0.001
106128

107129
def prepare_element(
@@ -112,8 +134,9 @@ def prepare_element(
112134

113135
# allowed = (DocItem, TableCell, RichTableCell, GraphCell)
114136
allowed = (DocItem, TableItem, GraphCell)
115-
assert isinstance(element, allowed) # too strict, could be DocItem, TableCell, RichTableCell, GraphCell
116-
137+
assert isinstance(
138+
element, allowed
139+
) # too strict, could be DocItem, TableCell, RichTableCell, GraphCell
117140

118141
if isinstance(element, KeyValueItem):
119142
# Yield from the graphCells inside here.
@@ -122,19 +145,29 @@ def prepare_element(
122145
element_prov = c.prov # Key / Value have only one provenance!
123146
bbox = element_prov.bbox
124147
page_ix = element_prov.page_no
125-
bbox = bbox.scale_to_size(old_size=conv_res.document.pages[page_ix].size, new_size=conv_res.document.pages[page_ix].image.size)
148+
bbox = bbox.scale_to_size(
149+
old_size=conv_res.document.pages[page_ix].size,
150+
new_size=conv_res.document.pages[page_ix].image.size,
151+
)
126152
expanded_bbox = bbox.expand_by_scale(
127153
x_scale=self.expansion_factor, y_scale=self.expansion_factor
128-
).to_top_left_origin(page_height=conv_res.document.pages[page_ix].image.size.height)
154+
).to_top_left_origin(
155+
page_height=conv_res.document.pages[page_ix].image.size.height
156+
)
129157

130158
good_bbox = True
131-
if expanded_bbox.l > expanded_bbox.r or expanded_bbox.t > expanded_bbox.b:
159+
if (
160+
expanded_bbox.l > expanded_bbox.r
161+
or expanded_bbox.t > expanded_bbox.b
162+
):
132163
good_bbox = False
133164

134165
if good_bbox:
135-
cropped_image = conv_res.document.pages[page_ix].image.pil_image.crop(expanded_bbox.as_tuple())
166+
cropped_image = conv_res.document.pages[
167+
page_ix
168+
].image.pil_image.crop(expanded_bbox.as_tuple())
136169
# cropped_image.show()
137-
result.append(OcrEnrichmentElement(item=c, image=cropped_image))
170+
result.append(OcrEnrichmentElement(item=c, image=[cropped_image]))
138171
return result
139172
elif isinstance(element, TableItem):
140173
element_prov = element.prov[0]
@@ -145,43 +178,75 @@ def prepare_element(
145178
if hasattr(cell, "bbox"):
146179
if cell.bbox:
147180
bbox = cell.bbox
148-
bbox = bbox.scale_to_size(old_size=conv_res.document.pages[page_ix].size, new_size=conv_res.document.pages[page_ix].image.size)
181+
bbox = bbox.scale_to_size(
182+
old_size=conv_res.document.pages[page_ix].size,
183+
new_size=conv_res.document.pages[page_ix].image.size,
184+
)
149185
expanded_bbox = bbox.expand_by_scale(
150-
x_scale=self.expansion_factor, y_scale=self.expansion_factor
151-
).to_top_left_origin(page_height=conv_res.document.pages[page_ix].image.size.height)
186+
x_scale=self.expansion_factor,
187+
y_scale=self.expansion_factor,
188+
).to_top_left_origin(
189+
page_height=conv_res.document.pages[
190+
page_ix
191+
].image.size.height
192+
)
152193

153194
good_bbox = True
154-
if expanded_bbox.l > expanded_bbox.r or expanded_bbox.t > expanded_bbox.b:
195+
if (
196+
expanded_bbox.l > expanded_bbox.r
197+
or expanded_bbox.t > expanded_bbox.b
198+
):
155199
good_bbox = False
156200

157201
if good_bbox:
158-
cropped_image = conv_res.document.pages[page_ix].image.pil_image.crop(expanded_bbox.as_tuple())
202+
cropped_image = conv_res.document.pages[
203+
page_ix
204+
].image.pil_image.crop(expanded_bbox.as_tuple())
159205
# cropped_image.show()
160-
result.append(OcrEnrichmentElement(item=cell, image=cropped_image))
206+
result.append(
207+
OcrEnrichmentElement(
208+
item=cell, image=[cropped_image]
209+
)
210+
)
161211
return result
162212
else:
213+
multiple_crops = []
163214
# Crop the image from the page
164-
element_prov = element.prov[0] # TODO: Not all items have prov
165-
bbox = element_prov.bbox
215+
for element_prov in element.prov:
216+
# element_prov = element.prov[0] # TODO: Not all items have prov
217+
bbox = element_prov.bbox
166218

167-
page_ix = element_prov.page_no
168-
bbox = bbox.scale_to_size(old_size=conv_res.document.pages[page_ix].size, new_size=conv_res.document.pages[page_ix].image.size)
169-
expanded_bbox = bbox.expand_by_scale(
170-
x_scale=self.expansion_factor, y_scale=self.expansion_factor
171-
).to_top_left_origin(page_height=conv_res.document.pages[page_ix].image.size.height)
172-
173-
good_bbox = True
174-
if expanded_bbox.l > expanded_bbox.r or expanded_bbox.t > expanded_bbox.b:
175-
good_bbox = False
176-
177-
if good_bbox:
178-
cropped_image = conv_res.document.pages[page_ix].image.pil_image.crop(expanded_bbox.as_tuple())
179-
cropped_image.show()
180-
# Return the proper cropped image
181-
return [OcrEnrichmentElement(item=element, image=cropped_image)]
219+
page_ix = element_prov.page_no
220+
bbox = bbox.scale_to_size(
221+
old_size=conv_res.document.pages[page_ix].size,
222+
new_size=conv_res.document.pages[page_ix].image.size,
223+
)
224+
expanded_bbox = bbox.expand_by_scale(
225+
x_scale=self.expansion_factor, y_scale=self.expansion_factor
226+
).to_top_left_origin(
227+
page_height=conv_res.document.pages[page_ix].image.size.height
228+
)
229+
230+
good_bbox = True
231+
if (
232+
expanded_bbox.l > expanded_bbox.r
233+
or expanded_bbox.t > expanded_bbox.b
234+
):
235+
good_bbox = False
236+
237+
if good_bbox:
238+
cropped_image = conv_res.document.pages[
239+
page_ix
240+
].image.pil_image.crop(expanded_bbox.as_tuple())
241+
multiple_crops.append(cropped_image)
242+
# cropped_image.show()
243+
# Return the proper cropped image
244+
multiple_crops
245+
if len(multiple_crops) > 0:
246+
return [OcrEnrichmentElement(item=element, image=multiple_crops)]
182247
else:
183248
return []
184-
249+
185250
@classmethod
186251
def get_options_type(cls) -> type[PictureDescriptionApiOptions]:
187252
return PictureDescriptionApiOptions
@@ -239,34 +304,60 @@ def __call__(
239304

240305
elements: list[TextItem] = []
241306
images: list[Image.Image] = []
307+
img_ind_per_element: list[int] = []
308+
242309
for element_stack in element_batch:
243310
for element in element_stack:
244311
allowed = (DocItem, TableCell, RichTableCell, GraphCell)
245312
assert isinstance(element.item, allowed)
246-
elements.append(element.item)
247-
images.append(element.image)
313+
for ind, img in enumerate(element.image):
314+
elements.append(element.item)
315+
images.append(img)
316+
# images.append(element.image)
317+
img_ind_per_element.append(ind)
248318

249319
if not images:
250320
return
251321

252322
outputs = list(self._annotate_images(images))
253323

254-
for item, output in zip(elements, outputs):
324+
for item, output, img_ind in zip(elements, outputs, img_ind_per_element):
255325
# Sometimes the model returns HTML tags, which are not needed in our case, so it is better to strip them
256326
def clean_html_tags(text):
257-
for tag in ["<table>", "<tr>", "<td>", "<strong>", "</table>", "</tr>", "</td>", "</strong>", "<th>", "</th>", "<tbody>", "<tbody>", "<thead>", "</thead>"]:
327+
for tag in [
328+
"<table>",
329+
"<tr>",
330+
"<td>",
331+
"<strong>",
332+
"</table>",
333+
"</tr>",
334+
"</td>",
335+
"</strong>",
336+
"<th>",
337+
"</th>",
338+
"<tbody>",
339+
"<tbody>",
340+
"<thead>",
341+
"</thead>",
342+
]:
258343
text = text.replace(tag, "")
259344
return text
345+
260346
output = clean_html_tags(output)
261347

262348
if isinstance(item, (TextItem)):
263-
print("OLD TEXT: {}".format(item.text))
264-
print("NEW TEXT: {}".format(output))
349+
print(f"OLD TEXT: {item.text}")
350+
print(f"NEW TEXT: {output}")
265351

266352
# Re-populate text
267353
if isinstance(item, (TextItem, GraphCell)):
268-
item.text = output
269-
item.orig = output
354+
if img_ind > 0:
355+
# Concat texts across several provenances
356+
item.text += " " + output
357+
item.orig += " " + output
358+
else:
359+
item.text = output
360+
item.orig = output
270361
elif isinstance(item, (TableCell, RichTableCell)):
271362
item.text = output
272363
elif isinstance(item, PictureItem):
@@ -292,14 +383,22 @@ def main() -> None:
292383
pipeline_options.generate_picture_images = True
293384
pipeline_options.images_scale = 4.0
294385

295-
doc_converter = DocumentConverter( # all of the below is optional, has internal defaults.
296-
allowed_formats=[InputFormat.PDF],
297-
format_options={InputFormat.PDF: PdfFormatOption(pipeline_cls=StandardPdfPipeline, pipeline_options=pipeline_options)}
386+
doc_converter = (
387+
DocumentConverter( # all of the below is optional, has internal defaults.
388+
allowed_formats=[InputFormat.PDF],
389+
format_options={
390+
InputFormat.PDF: PdfFormatOption(
391+
pipeline_cls=StandardPdfPipeline, pipeline_options=pipeline_options
392+
)
393+
},
394+
)
298395
)
396+
299397
print("Converting PDF to get a Docling document json with embedded page images...")
300398
conv_result = doc_converter.convert(PDF_DOC)
301-
# conv_result.document.save_as_json(filename = JSON_DOC, image_mode = ImageRefMode.EMBEDDED)
302-
conv_result.document.save_as_json(filename = JSON_DOC, image_mode = ImageRefMode.REFERENCED)
399+
conv_result.document.save_as_json(
400+
filename=JSON_DOC, image_mode=ImageRefMode.EMBEDDED
401+
)
303402

304403
md1 = conv_result.document.export_to_markdown()
305404
print("*** ORIGINAL MARKDOWN ***")
@@ -324,7 +423,7 @@ def main() -> None:
324423
InputFormat.JSON_DOCLING: FormatOption(
325424
pipeline_cls=OcrEnrichmentPipeline,
326425
pipeline_options=pipeline_options,
327-
backend=DoclingJSONBackend
426+
backend=DoclingJSONBackend,
328427
)
329428
}
330429
)

0 commit comments

Comments
 (0)