Skip to content

Commit 98c60bb

Browse files
authored
perf: fix serialization performance (#249)
* perf: fix serialization performance Signed-off-by: Panos Vagenas <[email protected]> * add missing leading HTML tag Signed-off-by: Panos Vagenas <[email protected]> --------- Signed-off-by: Panos Vagenas <[email protected]>
1 parent 9ac2425 commit 98c60bb

24 files changed

+47248
-133
lines changed

docling_core/experimental/serializer/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,15 @@ def serialize_captions(
234234
...
235235

236236
@abstractmethod
237-
def get_excluded_refs(self, **kwargs) -> list[str]:
237+
def get_excluded_refs(self, **kwargs) -> set[str]:
238238
"""Get references to excluded items."""
239239
...
240240

241+
@abstractmethod
242+
def requires_page_break(self) -> bool:
243+
"""Whether to add page breaks."""
244+
...
245+
241246

242247
class BaseSerializerProvider(ABC):
243248
"""Base class for document serializer providers."""

docling_core/experimental/serializer/common.py

Lines changed: 105 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
#
55

66
"""Define base classes for serialization."""
7+
import re
78
import sys
89
from abc import abstractmethod
9-
from copy import deepcopy
1010
from functools import cached_property
1111
from pathlib import Path
12-
from typing import Any, Optional, Union
12+
from typing import Any, Iterable, Optional, Tuple, Union
1313

1414
from pydantic import AnyUrl, BaseModel, NonNegativeInt, computed_field
1515
from typing_extensions import Self, override
@@ -50,6 +50,49 @@
5050
_DEFAULT_LAYERS = {cl for cl in ContentLayer}
5151

5252

53+
class _PageBreakNode(NodeItem):
54+
"""Page break node."""
55+
56+
prev_page: int
57+
next_page: int
58+
59+
60+
class _PageBreakSerResult(SerializationResult):
61+
"""Page break serialization result."""
62+
63+
node: _PageBreakNode
64+
65+
66+
def _iterate_items(
67+
doc: DoclingDocument,
68+
layers: Optional[set[ContentLayer]],
69+
node: Optional[NodeItem] = None,
70+
traverse_pictures: bool = False,
71+
add_page_breaks: bool = False,
72+
):
73+
prev_page_nr: Optional[int] = None
74+
page_break_i = 0
75+
for item, _ in doc.iterate_items(
76+
root=node,
77+
with_groups=True,
78+
included_content_layers=layers,
79+
traverse_pictures=traverse_pictures,
80+
):
81+
if isinstance(item, DocItem):
82+
if item.prov:
83+
page_no = item.prov[0].page_no
84+
if add_page_breaks and (prev_page_nr is None or page_no > prev_page_nr):
85+
if prev_page_nr is not None: # close previous range
86+
yield _PageBreakNode(
87+
self_ref=f"#/pb/{page_break_i}",
88+
prev_page=prev_page_nr,
89+
next_page=page_no,
90+
)
91+
page_break_i += 1
92+
prev_page_nr = page_no
93+
yield item
94+
95+
5396
def create_ser_result(
5497
*,
5598
text: str = "",
@@ -128,7 +171,7 @@ class Config:
128171

129172
params: CommonParams = CommonParams()
130173

131-
_excluded_refs_cache: dict[str, list[str]] = {}
174+
_excluded_refs_cache: dict[str, set[str]] = {}
132175

133176
@computed_field # type: ignore[misc]
134177
@cached_property
@@ -146,19 +189,19 @@ def _captions_of_some_item(self) -> set[str]:
146189
return refs
147190

148191
@override
149-
def get_excluded_refs(self, **kwargs) -> list[str]:
192+
def get_excluded_refs(self, **kwargs) -> set[str]:
150193
"""References to excluded items."""
151194
params = self.params.merge_with_patch(patch=kwargs)
152195
params_json = params.model_dump_json()
153196
refs = self._excluded_refs_cache.get(params_json)
154197
if refs is None:
155-
refs = [
198+
refs = {
156199
item.self_ref
157-
for ix, (item, _) in enumerate(
158-
self.doc.iterate_items(
159-
with_groups=True,
200+
for ix, item in enumerate(
201+
_iterate_items(
202+
doc=self.doc,
160203
traverse_pictures=True,
161-
included_content_layers=params.layers,
204+
layers=params.layers,
162205
)
163206
)
164207
if (
@@ -178,64 +221,21 @@ def get_excluded_refs(self, **kwargs) -> list[str]:
178221
)
179222
)
180223
)
181-
]
224+
}
182225
self._excluded_refs_cache[params_json] = refs
183226
return refs
184227

185-
@abstractmethod
186-
def serialize_page(
187-
self, *, parts: list[SerializationResult], **kwargs
188-
) -> SerializationResult:
189-
"""Serialize a page out of its parts."""
190-
...
191-
192228
@abstractmethod
193229
def serialize_doc(
194-
self, *, pages: dict[Optional[int], SerializationResult], **kwargs
230+
self, *, parts: list[SerializationResult], **kwargs
195231
) -> SerializationResult:
196232
"""Serialize a document out of its pages."""
197233
...
198234

199235
def _serialize_body(self) -> SerializationResult:
200236
"""Serialize the document body."""
201-
# find page ranges if available; otherwise regard whole doc as a single page
202-
prev_start: int = 0
203-
prev_page_nr: Optional[int] = None
204-
range_by_page_nr: dict[Optional[int], tuple[int, int]] = {}
205-
206-
for ix, (item, _) in enumerate(
207-
self.doc.iterate_items(
208-
with_groups=True,
209-
traverse_pictures=True,
210-
included_content_layers=self.params.layers,
211-
)
212-
):
213-
if isinstance(item, DocItem):
214-
if item.prov:
215-
page_no = item.prov[0].page_no
216-
if prev_page_nr is None or page_no > prev_page_nr:
217-
if prev_page_nr is not None: # close previous range
218-
range_by_page_nr[prev_page_nr] = (prev_start, ix)
219-
220-
prev_start = ix
221-
# could alternatively always start 1st page from 0:
222-
# prev_start = ix if prev_page_nr is not None else 0
223-
224-
prev_page_nr = page_no
225-
226-
# close last (and single if no pages) range
227-
range_by_page_nr[prev_page_nr] = (prev_start, sys.maxsize)
228-
229-
page_results: dict[Optional[int], SerializationResult] = {}
230-
for page_nr in range_by_page_nr:
231-
page_range = range_by_page_nr[page_nr]
232-
params_to_pass = deepcopy(self.params)
233-
params_to_pass.start_idx = page_range[0]
234-
params_to_pass.stop_idx = page_range[1]
235-
subparts = self.get_parts(**params_to_pass.model_dump())
236-
page_res = self.serialize_page(parts=subparts)
237-
page_results[page_nr] = page_res
238-
res = self.serialize_doc(pages=page_results)
237+
subparts = self.get_parts()
238+
res = self.serialize_doc(parts=subparts)
239239
return res
240240

241241
@override
@@ -331,6 +331,11 @@ def serialize(
331331
doc=self.doc,
332332
**my_kwargs,
333333
)
334+
elif isinstance(item, _PageBreakNode):
335+
part = _PageBreakSerResult(
336+
text=self._create_page_break(node=item),
337+
node=item,
338+
)
334339
else:
335340
part = self.fallback_serializer.serialize(
336341
item=item,
@@ -356,18 +361,19 @@ def get_parts(
356361
parts: list[SerializationResult] = []
357362
my_visited: set[str] = visited if visited is not None else set()
358363
params = self.params.merge_with_patch(patch=kwargs)
359-
for item, _ in self.doc.iterate_items(
360-
root=item,
361-
with_groups=True,
362-
traverse_pictures=traverse_pictures,
363-
included_content_layers=params.layers,
364+
365+
for node in _iterate_items(
366+
node=item,
367+
doc=self.doc,
368+
layers=params.layers,
369+
add_page_breaks=self.requires_page_break(),
364370
):
365-
if item.self_ref in my_visited:
371+
if node.self_ref in my_visited:
366372
continue
367373
else:
368-
my_visited.add(item.self_ref)
374+
my_visited.add(node.self_ref)
369375
part = self.serialize(
370-
item=item,
376+
item=node,
371377
list_level=list_level,
372378
is_inline_scope=is_inline_scope,
373379
visited=my_visited,
@@ -450,3 +456,38 @@ def serialize_captions(
450456
else:
451457
text_res = ""
452458
return create_ser_result(text=text_res, span_source=results)
459+
460+
def _get_applicable_pages(self) -> Optional[list[int]]:
461+
pages = {
462+
item.prov[0].page_no: ...
463+
for ix, (item, _) in enumerate(
464+
self.doc.iterate_items(
465+
with_groups=True,
466+
included_content_layers=self.params.layers,
467+
traverse_pictures=True,
468+
)
469+
)
470+
if (
471+
isinstance(item, DocItem)
472+
and item.prov
473+
and (
474+
self.params.pages is None
475+
or item.prov[0].page_no in self.params.pages
476+
)
477+
and ix >= self.params.start_idx
478+
and ix < self.params.stop_idx
479+
)
480+
}
481+
return [p for p in pages] or None
482+
483+
def _create_page_break(self, node: _PageBreakNode) -> str:
484+
return f"#_#_DOCLING_DOC_PAGE_BREAK_{node.prev_page}_{node.next_page}_#_#"
485+
486+
def _get_page_breaks(self, text: str) -> Iterable[Tuple[str, int, int]]:
487+
pattern = r"#_#_DOCLING_DOC_PAGE_BREAK_(\d+)_(\d+)_#_#"
488+
matches = re.finditer(pattern, text)
489+
for match in matches:
490+
full_match = match.group(0)
491+
prev_page_nr = int(match.group(1))
492+
next_page_nr = int(match.group(2))
493+
yield (full_match, prev_page_nr, next_page_nr)

docling_core/experimental/serializer/doctags.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -476,28 +476,21 @@ class DocTagsDocSerializer(DocSerializer):
476476
params: DocTagsParams = DocTagsParams()
477477

478478
@override
479-
def serialize_page(
479+
def serialize_doc(
480480
self, *, parts: list[SerializationResult], **kwargs
481481
) -> SerializationResult:
482-
"""Serialize a page out of its parts."""
482+
"""Serialize a document out of its pages."""
483483
delim = _get_delim(params=self.params)
484484
text_res = delim.join([p.text for p in parts if p.text])
485-
return create_ser_result(text=text_res, span_source=parts)
486485

487-
@override
488-
def serialize_doc(
489-
self, *, pages: dict[Optional[int], SerializationResult], **kwargs
490-
) -> SerializationResult:
491-
"""Serialize a document out of its pages."""
492-
delim = _get_delim(params=self.params)
493486
if self.params.add_page_break:
494-
page_sep = f"{delim}<{DocumentToken.PAGE_BREAK.value}>{delim}"
495-
content = page_sep.join([text for k in pages if (text := pages[k].text)])
496-
else:
497-
content = self.serialize_page(parts=list(pages.values())).text
487+
page_sep = f"<{DocumentToken.PAGE_BREAK.value}>"
488+
for full_match, _, _ in self._get_page_breaks(text=text_res):
489+
text_res = text_res.replace(full_match, page_sep)
490+
498491
wrap_tag = DocumentToken.DOCUMENT.value
499-
text_res = f"<{wrap_tag}>{content}{delim}</{wrap_tag}>"
500-
return create_ser_result(text=text_res, span_source=list(pages.values()))
492+
text_res = f"<{wrap_tag}>{text_res}{delim}</{wrap_tag}>"
493+
return create_ser_result(text=text_res, span_source=parts)
501494

502495
@override
503496
def serialize_captions(
@@ -526,3 +519,8 @@ def serialize_captions(
526519
if text_res:
527520
text_res = _wrap(text=text_res, wrap_tag=DocumentToken.CAPTION.value)
528521
return create_ser_result(text=text_res, span_source=results)
522+
523+
@override
524+
def requires_page_break(self):
525+
"""Whether to add page breaks."""
526+
return self.params.add_page_break

0 commit comments

Comments
 (0)