
Commit 48cedcd

Author: Maksym Lysak (committed)

Example of how to apply OCR to a Docling Document as a post-processing step with "nanonets-ocr2-3b" via LM Studio

Signed-off-by: Maksym Lysak <[email protected]>

1 parent 3e6da2c · commit 48cedcd

File tree

1 file changed: +344 −0 lines changed
from __future__ import annotations

from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional, Union

from PIL import Image
from docling_core.types.doc import (
    DoclingDocument,
    ImageRefMode,
    NodeItem,
    TextItem,
)
from docling_core.types.doc.document import (
    ContentLayer,
    DocItem,
    GraphCell,
    KeyValueItem,
    PictureItem,
    RichTableCell,
    TableCell,
    TableItem,
)
from pydantic import BaseModel, ConfigDict

from docling.backend.json.docling_json_backend import DoclingJSONBackend
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import InputFormat, ItemAndImageEnrichmentElement
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
    ConvertPipelineOptions,
    PdfPipelineOptions,
    PictureDescriptionApiOptions,
)
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BaseModelWithOptions, GenericEnrichmentModel
from docling.pipeline.simple_pipeline import SimplePipeline
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from docling.utils.api_image_request import api_image_request
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
# Example of how to apply OCR to a Docling Document as a post-processing step
# with "nanonets-ocr2-3b" via LM Studio.
# Requires LM Studio running an inference server with the "nanonets-ocr2-3b"
# model pre-loaded.
LM_STUDIO_URL = "http://localhost:1234/v1/chat/completions"
LM_STUDIO_MODEL = "nanonets-ocr2-3b"
DEFAULT_PROMPT = "Extract the text from the above document as if you were reading it naturally."

PDF_DOC = "tests/data/pdf/2305.03393v1-pg9.pdf"
# PDF_DOC = "tests/data/pdf/amt_handbook_sample.pdf"
# PDF_DOC = "tests/data/pdf/2206.01062.pdf"
JSON_DOC = "scratch/test_doc.json"
POST_PROCESSED_JSON_DOC = "scratch/test_doc_ocr.json"
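
# Optional sanity check (editorial addition, not part of the original
# example): verify that the LM Studio server is reachable before converting.
# This assumes LM Studio's OpenAI-compatible GET /v1/models endpoint on the
# same host and port as LM_STUDIO_URL.
def lm_studio_is_reachable(url: str = "http://localhost:1234/v1/models") -> bool:
    import urllib.request

    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.status == 200
    except OSError:
        return False
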
class OcrEnrichmentElement(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    item: Union[DocItem, TableCell, RichTableCell, GraphCell]
    image: Image.Image  # TODO: may need to be a list of images for multi-provenance items.


class OcrEnrichmentPipelineOptions(ConvertPipelineOptions):
    api_options: PictureDescriptionApiOptions


class OcrEnrichmentPipeline(SimplePipeline):
    def __init__(self, pipeline_options: OcrEnrichmentPipelineOptions):
        super().__init__(pipeline_options)
        self.pipeline_options: OcrEnrichmentPipelineOptions

        self.enrichment_pipe = [
            OcrApiEnrichmentModel(
                enabled=True,
                enable_remote_services=True,
                artifacts_path=None,
                options=self.pipeline_options.api_options,
                accelerator_options=AcceleratorOptions(),
            )
        ]

    @classmethod
    def get_default_options(cls) -> OcrEnrichmentPipelineOptions:
        return OcrEnrichmentPipelineOptions()

    def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
        def _prepare_elements(
            conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
        ) -> Iterable[NodeItem]:
            # All content layers, with traverse_pictures=True.
            for doc_element, _level in conv_res.document.iterate_items(
                traverse_pictures=True,
                included_content_layers={
                    ContentLayer.BODY,
                    ContentLayer.FURNITURE,
                },
            ):
                prepared_elements = model.prepare_element(  # TODO: make this one yield multiple items.
                    conv_res=conv_res, element=doc_element
                )
                if prepared_elements is not None:
                    yield prepared_elements

        with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
            for model in self.enrichment_pipe:
                for element_batch in chunkify(
                    _prepare_elements(conv_res, model),
                    model.elements_batch_size,
                ):
                    for element in model(
                        doc=conv_res.document, element_batch=element_batch
                    ):  # Must exhaust!
                        pass
        return conv_res

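
# The enrichment model below does the actual OCR work: prepare_element()
# crops a page-image region for each processable item (key/value graph cells,
# table cells, or any other DocItem with provenance), and __call__() sends the
# crops to the vision-language model and writes the recognized text back.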
class OcrApiEnrichmentModel(
    GenericEnrichmentModel[OcrEnrichmentElement], BaseModelWithOptions
):
    expansion_factor: float = 0.001  # overridden in __init__ below

    def prepare_element(
        self, conv_res: ConversionResult, element: NodeItem
    ) -> Optional[list[OcrEnrichmentElement]]:
        if not self.is_processable(doc=conv_res.document, element=element):
            return None

        # allowed = (DocItem, TableCell, RichTableCell, GraphCell)
        allowed = (DocItem, TableItem, GraphCell)
        assert isinstance(element, allowed)  # too strict; could be DocItem, TableCell, RichTableCell, GraphCell

        if isinstance(element, KeyValueItem):
            # Yield from the GraphCells inside the key/value item.
            result = []
            for c in element.graph.cells:
                element_prov = c.prov  # Key/Value cells have only one provenance!
                bbox = element_prov.bbox
                page_ix = element_prov.page_no
                bbox = bbox.scale_to_size(
                    old_size=conv_res.document.pages[page_ix].size,
                    new_size=conv_res.document.pages[page_ix].image.size,
                )
                expanded_bbox = bbox.expand_by_scale(
                    x_scale=self.expansion_factor, y_scale=self.expansion_factor
                ).to_top_left_origin(
                    page_height=conv_res.document.pages[page_ix].image.size.height
                )

                good_bbox = True
                if expanded_bbox.l > expanded_bbox.r or expanded_bbox.t > expanded_bbox.b:
                    good_bbox = False

                if good_bbox:
                    cropped_image = conv_res.document.pages[
                        page_ix
                    ].image.pil_image.crop(expanded_bbox.as_tuple())
                    # cropped_image.show()
                    result.append(OcrEnrichmentElement(item=c, image=cropped_image))
            return result
        elif isinstance(element, TableItem):
            element_prov = element.prov[0]
            page_ix = element_prov.page_no
            result = []
            for row in element.data.grid:
                for cell in row:
                    if hasattr(cell, "bbox") and cell.bbox:
                        bbox = cell.bbox
                        bbox = bbox.scale_to_size(
                            old_size=conv_res.document.pages[page_ix].size,
                            new_size=conv_res.document.pages[page_ix].image.size,
                        )
                        expanded_bbox = bbox.expand_by_scale(
                            x_scale=self.expansion_factor,
                            y_scale=self.expansion_factor,
                        ).to_top_left_origin(
                            page_height=conv_res.document.pages[
                                page_ix
                            ].image.size.height
                        )

                        good_bbox = True
                        if (
                            expanded_bbox.l > expanded_bbox.r
                            or expanded_bbox.t > expanded_bbox.b
                        ):
                            good_bbox = False

                        if good_bbox:
                            cropped_image = conv_res.document.pages[
                                page_ix
                            ].image.pil_image.crop(expanded_bbox.as_tuple())
                            # cropped_image.show()
                            result.append(
                                OcrEnrichmentElement(item=cell, image=cropped_image)
                            )
            return result
        else:
            # Crop the image from the page.
            element_prov = element.prov[0]  # TODO: not all items have provenance.
            bbox = element_prov.bbox

            page_ix = element_prov.page_no
            bbox = bbox.scale_to_size(
                old_size=conv_res.document.pages[page_ix].size,
                new_size=conv_res.document.pages[page_ix].image.size,
            )
            expanded_bbox = bbox.expand_by_scale(
                x_scale=self.expansion_factor, y_scale=self.expansion_factor
            ).to_top_left_origin(
                page_height=conv_res.document.pages[page_ix].image.size.height
            )

            good_bbox = True
            if expanded_bbox.l > expanded_bbox.r or expanded_bbox.t > expanded_bbox.b:
                good_bbox = False

            if good_bbox:
                cropped_image = conv_res.document.pages[
                    page_ix
                ].image.pil_image.crop(expanded_bbox.as_tuple())
                # cropped_image.show()  # debug only; disabled to avoid opening a window per element
                # Return the proper cropped image.
                return [OcrEnrichmentElement(item=element, image=cropped_image)]
            else:
                return []
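
    # Editorial sketch (not used by the original example): the scale ->
    # expand -> to_top_left_origin -> validate -> crop sequence above is
    # repeated three times and could be factored into a helper like this:
    def _crop_bbox(
        self, conv_res: ConversionResult, bbox, page_ix: int
    ) -> Optional[Image.Image]:
        page = conv_res.document.pages[page_ix]
        # Map the bbox from page coordinates into page-image coordinates.
        bbox = bbox.scale_to_size(old_size=page.size, new_size=page.image.size)
        expanded_bbox = bbox.expand_by_scale(
            x_scale=self.expansion_factor, y_scale=self.expansion_factor
        ).to_top_left_origin(page_height=page.image.size.height)
        # Reject degenerate boxes.
        if expanded_bbox.l > expanded_bbox.r or expanded_bbox.t > expanded_bbox.b:
            return None
        return page.image.pil_image.crop(expanded_bbox.as_tuple())
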
    @classmethod
    def get_options_type(cls) -> type[PictureDescriptionApiOptions]:
        return PictureDescriptionApiOptions

    def __init__(
        self,
        *,
        enabled: bool,
        enable_remote_services: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: PictureDescriptionApiOptions,
        accelerator_options: AcceleratorOptions,
    ):
        self.enabled = enabled
        self.options = options
        self.concurrency = 4
        self.expansion_factor = 0.05
        self.elements_batch_size = 4
        self._accelerator_options = accelerator_options
        self._artifacts_path = (
            Path(artifacts_path) if isinstance(artifacts_path, str) else artifacts_path
        )

        if self.enabled and not enable_remote_services:
            raise OperationNotAllowed(
                "Enable remote services by setting pipeline_options.enable_remote_services=True."
            )

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        return self.enabled

    def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
        def _api_request(image: Image.Image) -> str:
            return api_image_request(
                image=image,
                prompt=self.options.prompt,
                url=self.options.url,
                timeout=self.options.timeout,
                headers=self.options.headers,
                **self.options.params,
            )

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            yield from executor.map(_api_request, images)
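
    # Note: executor.map() above yields results in input order, which is what
    # lets zip(elements, outputs) in __call__ pair each item with its text.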
    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        elements: list[Union[DocItem, TableCell, RichTableCell, GraphCell]] = []
        images: list[Image.Image] = []
        for element_stack in element_batch:
            for element in element_stack:
                allowed = (DocItem, TableCell, RichTableCell, GraphCell)
                assert isinstance(element.item, allowed)
                elements.append(element.item)
                images.append(element.image)

        if not images:
            return

        outputs = list(self._annotate_images(images))

        # The model can sometimes return HTML tags, which are not needed in
        # our output, so it is better to strip them.
        def clean_html_tags(text: str) -> str:
            for tag in [
                "<table>", "</table>",
                "<tr>", "</tr>",
                "<td>", "</td>",
                "<th>", "</th>",
                "<thead>", "</thead>",
                "<tbody>", "</tbody>",
                "<strong>", "</strong>",
            ]:
                text = text.replace(tag, "")
            return text

        for item, output in zip(elements, outputs):
            output = clean_html_tags(output)

            if isinstance(item, TextItem):
                print("OLD TEXT: {}".format(item.text))
                print("NEW TEXT: {}".format(output))

            # Re-populate the text.
            if isinstance(item, (TextItem, GraphCell)):
                item.text = output
                item.orig = output
            elif isinstance(item, (TableCell, RichTableCell)):
                item.text = output
            elif isinstance(item, PictureItem):
                pass
            else:
                raise ValueError(f"Unknown item type: {type(item)}")

            # Take care of charspans for the relevant types.
            if isinstance(item, GraphCell):
                item.prov.charspan = [0, len(output)]
            elif isinstance(item, TextItem):
                item.prov[0].charspan = [0, len(output)]

            yield item


def main() -> None:
    # TODO: Properly process cases for items which have more than one provenance.

    # First, prepare a Docling document JSON with page images.
    pipeline_options = PdfPipelineOptions()
    pipeline_options.generate_page_images = True
    pipeline_options.generate_picture_images = True
    pipeline_options.images_scale = 4.0

    doc_converter = DocumentConverter(  # all of the below is optional; there are internal defaults.
        allowed_formats=[InputFormat.PDF],
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=StandardPdfPipeline, pipeline_options=pipeline_options
            )
        },
    )
    print("Converting PDF to get a Docling document JSON with page images...")
    conv_result = doc_converter.convert(PDF_DOC)
    # conv_result.document.save_as_json(filename=JSON_DOC, image_mode=ImageRefMode.EMBEDDED)
    conv_result.document.save_as_json(filename=JSON_DOC, image_mode=ImageRefMode.REFERENCED)

    md1 = conv_result.document.export_to_markdown()
    print("*** ORIGINAL MARKDOWN ***")
    print(md1)

    print("Post-processing all bounding boxes with OCR...")
    # Post-process OCR on top of the existing Docling document:
    pipeline_options = OcrEnrichmentPipelineOptions(
        api_options=PictureDescriptionApiOptions(
            url=LM_STUDIO_URL,
            prompt=DEFAULT_PROMPT,
            provenance="lm-studio-ocr",
            batch_size=4,
            concurrency=2,
            scale=2.0,
            params={"model": LM_STUDIO_MODEL},
        )
    )

    doc_converter = DocumentConverter(
        format_options={
            InputFormat.JSON_DOCLING: FormatOption(
                pipeline_cls=OcrEnrichmentPipeline,
                pipeline_options=pipeline_options,
                backend=DoclingJSONBackend,
            )
        }
    )
    result = doc_converter.convert(JSON_DOC)
    result.document.pages[1].image.pil_image.show()
    result.document.save_as_json(POST_PROCESSED_JSON_DOC)
    md = result.document.export_to_markdown()
    print("*** MARKDOWN ***")
    print(md)
    # print("*** KV ITEMS ***")
    # for kv_item in result.document.key_value_items:
    #     for kv_item_cell in kv_item.graph.cells:
    #         print("{} - {}".format(kv_item_cell.label, kv_item_cell.text))

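
# Usage note (editorial): run this script from the repository root so that the
# tests/data/... input path resolves, with LM Studio's local server running
# and "nanonets-ocr2-3b" loaded. The scratch/ output directory is assumed to
# exist before the JSON files are written.
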
if __name__ == "__main__":
    main()
