
Commit 23036e1

feat: add serializer provider to chunkers (#239)

Signed-off-by: Panos Vagenas <[email protected]>

1 parent 055742c

File tree

6 files changed, +114 −52 lines changed

docling_core/experimental/serializer/base.py

Lines changed: 9 additions & 0 deletions

@@ -237,3 +237,12 @@ def serialize_captions(
     def get_excluded_refs(self, **kwargs) -> list[str]:
         """Get references to excluded items."""
         ...
+
+
+class BaseSerializerProvider(ABC):
+    """Base class for document serializer providers."""
+
+    @abstractmethod
+    def get_serializer(self, doc: DoclingDocument) -> BaseDocSerializer:
+        """Get the associated serializer."""
+        ...
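The new abstraction decouples chunkers from any concrete serializer: an implementation returns a serializer bound to a given document. A minimal sketch of a custom provider, using only the classes introduced above (the class name is illustrative, not part of the commit):

from docling_core.experimental.serializer.base import (
    BaseDocSerializer,
    BaseSerializerProvider,
)
from docling_core.transforms.chunker.hierarchical_chunker import (
    ChunkingDocSerializer,
)
from docling_core.types.doc.document import DoclingDocument


class MyProvider(BaseSerializerProvider):
    # Illustrative subclass: returns the stock chunking serializer,
    # constructed per document so it can hold document state.
    def get_serializer(self, doc: DoclingDocument) -> BaseDocSerializer:
        return ChunkingDocSerializer(doc=doc)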

docling_core/transforms/chunker/hierarchical_chunker.py

Lines changed: 17 additions & 3 deletions

@@ -11,11 +11,12 @@
 import re
 from typing import Any, ClassVar, Final, Iterator, Literal, Optional
 
-from pydantic import Field, StringConstraints, field_validator
+from pydantic import ConfigDict, Field, StringConstraints, field_validator
 from typing_extensions import Annotated, override
 
 from docling_core.experimental.serializer.base import (
     BaseDocSerializer,
+    BaseSerializerProvider,
     BaseTableSerializer,
     SerializationResult,
 )
@@ -183,6 +184,15 @@ class ChunkingDocSerializer(MarkdownDocSerializer):
     )
 
 
+class ChunkingSerializerProvider(BaseSerializerProvider):
+    """Serializer provider used for chunking purposes."""
+
+    @override
+    def get_serializer(self, doc: DoclingDocument) -> BaseDocSerializer:
+        """Get the associated serializer."""
+        return ChunkingDocSerializer(doc=doc)
+
+
 class HierarchicalChunker(BaseChunker):
     r"""Chunker implementation leveraging the document layout.
 
@@ -192,12 +202,16 @@ class HierarchicalChunker(BaseChunker):
         delim (str): Delimiter to use for merging text. Defaults to "\n".
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    serializer_provider: BaseSerializerProvider = ChunkingSerializerProvider()
+
+    # deprecated:
     merge_list_items: Annotated[bool, Field(deprecated=True)] = True
 
     def chunk(
         self,
         dl_doc: DLDocument,
-        doc_serializer: Optional[BaseDocSerializer] = None,
         **kwargs: Any,
     ) -> Iterator[BaseChunk]:
         r"""Chunk the provided document.
@@ -208,7 +222,7 @@ def chunk(
         Yields:
             Iterator[Chunk]: iterator over extracted chunks
         """
-        my_doc_ser = doc_serializer or ChunkingDocSerializer(doc=dl_doc)
+        my_doc_ser = self.serializer_provider.get_serializer(doc=dl_doc)
         heading_by_level: dict[LevelNumber, str] = {}
         visited: set[str] = set()
         ser_res = create_ser_result()
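With this change, serializer customization moves from the chunk() call to the chunker's constructor. A minimal usage sketch, assuming dl_doc is an already-loaded DoclingDocument:

from docling_core.transforms.chunker import HierarchicalChunker
from docling_core.transforms.chunker.hierarchical_chunker import (
    ChunkingSerializerProvider,
)

# Passing the default provider explicitly; a custom subclass would go here.
chunker = HierarchicalChunker(serializer_provider=ChunkingSerializerProvider())
for chunk in chunker.chunk(dl_doc=dl_doc):
    print(chunk.text)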

docling_core/transforms/chunker/hybrid_chunker.py

Lines changed: 23 additions & 7 deletions

@@ -4,14 +4,23 @@
 #
 
 """Hybrid chunker implementation leveraging both doc structure & token awareness."""
-
 import warnings
+from functools import cached_property
 from typing import Any, Iterable, Iterator, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, PositiveInt, TypeAdapter, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    PositiveInt,
+    TypeAdapter,
+    computed_field,
+    model_validator,
+)
 from typing_extensions import Self
 
-from docling_core.transforms.chunker.hierarchical_chunker import ChunkingDocSerializer
+from docling_core.transforms.chunker.hierarchical_chunker import (
+    ChunkingSerializerProvider,
+)
 
 try:
     import semchunk
@@ -22,7 +31,10 @@
         "`pip install 'docling-core[chunking]'`"
     )
 
-from docling_core.experimental.serializer.base import BaseDocSerializer
+from docling_core.experimental.serializer.base import (
+    BaseDocSerializer,
+    BaseSerializerProvider,
+)
 from docling_core.transforms.chunker import (
     BaseChunk,
     BaseChunker,
@@ -52,7 +64,7 @@ class HybridChunker(BaseChunker):
     max_tokens: int = None  # type: ignore[assignment]
     merge_peers: bool = True
 
-    _inner_chunker: HierarchicalChunker = HierarchicalChunker()
+    serializer_provider: BaseSerializerProvider = ChunkingSerializerProvider()
 
     @model_validator(mode="after")
     def _patch_tokenizer_and_max_tokens(self) -> Self:
@@ -67,6 +79,11 @@ def _patch_tokenizer_and_max_tokens(self) -> Self:
             )
         return self
 
+    @computed_field  # type: ignore[misc]
+    @cached_property
+    def _inner_chunker(self) -> HierarchicalChunker:
+        return HierarchicalChunker(serializer_provider=self.serializer_provider)
+
     def _count_text_tokens(self, text: Optional[Union[str, list[str]]]):
         if text is None:
             return 0
@@ -246,7 +263,6 @@ def _merge_chunks_with_matching_metadata(self, chunks: list[DocChunk]):
     def chunk(
         self,
         dl_doc: DoclingDocument,
-        doc_serializer: Optional[BaseDocSerializer] = None,
         **kwargs: Any,
     ) -> Iterator[BaseChunk]:
         r"""Chunk the provided document.
@@ -257,7 +273,7 @@
         Yields:
             Iterator[Chunk]: iterator over extracted chunks
         """
-        my_doc_ser = doc_serializer or ChunkingDocSerializer(doc=dl_doc)
+        my_doc_ser = self.serializer_provider.get_serializer(doc=dl_doc)
         res: Iterable[DocChunk]
        res = self._inner_chunker.chunk(
            dl_doc=dl_doc,
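Note the `_inner_chunker` change: a class-level default instance could not see per-instance configuration, so it becomes a lazily built, cached property that forwards the configured provider. A standalone sketch of the pydantic pattern used here (toy classes, not from the commit):

from functools import cached_property

from pydantic import BaseModel, computed_field


class Inner(BaseModel):
    tag: str


class Outer(BaseModel):
    tag: str = "default"

    @computed_field  # type: ignore[misc]
    @cached_property
    def _inner(self) -> Inner:
        # Built on first access and then cached, so it reflects the
        # configuration of this particular instance.
        return Inner(tag=self.tag)


print(Outer(tag="custom")._inner.tag)  # custom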

examples/chunking_and_serialization.ipynb

Lines changed: 42 additions & 26 deletions

@@ -66,22 +66,12 @@
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "chunker.max_tokens=512\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from transformers import AutoTokenizer\n",
     "\n",
     "EMBED_MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
-    "tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)\n",
-    "chunker = HybridChunker(tokenizer=tokenizer)\n",
-    "print(f\"{chunker.max_tokens=}\")"
+    "tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)"
    ]
   },
   {
@@ -202,6 +192,8 @@
    }
   ],
   "source": [
+   "chunker = HybridChunker(tokenizer=tokenizer)\n",
+   "\n",
    "chunk_iter = chunker.chunk(dl_doc=doc)\n",
    "\n",
    "chunks = list(chunk_iter)\n",
@@ -279,12 +271,22 @@
    }
   ],
   "source": [
-   "doc_serializer = ChunkingDocSerializer(\n",
-   "    doc=doc,\n",
-   "    table_serializer=MarkdownTableSerializer(), # configuring a different table serializer\n",
+   "from docling_core.transforms.chunker.hierarchical_chunker import ChunkingSerializerProvider\n",
+   "\n",
+   "\n",
+   "class MDTableSerializerProvider(ChunkingSerializerProvider):\n",
+   "    def get_serializer(self, doc):\n",
+   "        return ChunkingDocSerializer(\n",
+   "            doc=doc,\n",
+   "            table_serializer=MarkdownTableSerializer(), # configuring a different table serializer\n",
+   "        )\n",
+   "\n",
+   "chunker = HybridChunker(\n",
+   "    tokenizer=tokenizer,\n",
+   "    serializer_provider=MDTableSerializerProvider(),\n",
    ")\n",
    "\n",
-   "chunk_iter = chunker.chunk(dl_doc=doc, doc_serializer=doc_serializer)\n",
+   "chunk_iter = chunker.chunk(dl_doc=doc)\n",
    "\n",
    "chunks = list(chunk_iter)\n",
    "i, chunk = find_n_th_chunk_with_label(chunks, n=0, label=DocItemLabel.TABLE)\n",
@@ -355,14 +357,21 @@
   "source": [
    "from docling_core.experimental.serializer.markdown import MarkdownParams\n",
    "\n",
-   "doc_serializer = ChunkingDocSerializer(\n",
-   "    doc=doc,\n",
-   "    params=MarkdownParams(\n",
-   "        image_placeholder=\"<!-- image -->\",\n",
-   "    ),\n",
+   "class ImgPlaceholderSerializerProvider(ChunkingSerializerProvider):\n",
+   "    def get_serializer(self, doc):\n",
+   "        return ChunkingDocSerializer(\n",
+   "            doc=doc,\n",
+   "            params=MarkdownParams(\n",
+   "                image_placeholder=\"<!-- image -->\",\n",
+   "            ),\n",
+   "        )\n",
+   "\n",
+   "chunker = HybridChunker(\n",
+   "    tokenizer=tokenizer,\n",
+   "    serializer_provider=ImgPlaceholderSerializerProvider(),\n",
    ")\n",
    "\n",
-   "chunk_iter = chunker.chunk(dl_doc=doc, doc_serializer=doc_serializer)\n",
+   "chunk_iter = chunker.chunk(dl_doc=doc)\n",
    "\n",
    "chunks = list(chunk_iter)\n",
    "i, chunk = find_n_th_chunk_with_label(chunks, n=0, label=DocItemLabel.PICTURE)\n",
@@ -466,12 +475,19 @@
   }
  ],
  "source": [
-   "doc_serializer = ChunkingDocSerializer(\n",
-   "    doc=doc,\n",
-   "    picture_serializer=AnnotationPictureSerializer(), # configuring a different picture serializer\n",
+   "class ImgAnnotationSerializerProvider(ChunkingSerializerProvider):\n",
+   "    def get_serializer(self, doc):\n",
+   "        return ChunkingDocSerializer(\n",
+   "            doc=doc,\n",
+   "            picture_serializer=AnnotationPictureSerializer(), # configuring a different picture serializer\n",
+   "        )\n",
+   "\n",
+   "chunker = HybridChunker(\n",
+   "    tokenizer=tokenizer,\n",
+   "    serializer_provider=ImgAnnotationSerializerProvider(),\n",
    ")\n",
    "\n",
-   "chunk_iter = chunker.chunk(dl_doc=doc, doc_serializer=doc_serializer)\n",
+   "chunk_iter = chunker.chunk(dl_doc=doc)\n",
    "\n",
    "chunks = list(chunk_iter)\n",
    "i, chunk = find_n_th_chunk_with_label(chunks, n=0, label=DocItemLabel.PICTURE)\n",

test/test_hierarchical_chunker.py

Lines changed: 12 additions & 8 deletions

@@ -9,9 +9,11 @@
 from docling_core.transforms.chunker import HierarchicalChunker
 from docling_core.transforms.chunker.hierarchical_chunker import (
     ChunkingDocSerializer,
+    ChunkingSerializerProvider,
     DocChunk,
 )
 from docling_core.types.doc import DoclingDocument as DLDocument
+from docling_core.types.doc.document import DoclingDocument
 
 from .test_data_gen_flag import GEN_TEST_DATA
 
@@ -48,18 +50,20 @@ def test_chunk_custom_serializer():
     with open("test/data/chunker/0_inp_dl_doc.json", encoding="utf-8") as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)
+
+    class MySerializerProvider(ChunkingSerializerProvider):
+        def get_serializer(self, doc: DoclingDocument):
+            return ChunkingDocSerializer(
+                doc=doc,
+                table_serializer=MarkdownTableSerializer(),
+            )
+
     chunker = HierarchicalChunker(
         merge_list_items=True,
-    )
-    doc_serializer = ChunkingDocSerializer(
-        doc=dl_doc,
-        table_serializer=MarkdownTableSerializer(),
+        serializer_provider=MySerializerProvider(),
     )
 
-    chunks = chunker.chunk(
-        dl_doc=dl_doc,
-        doc_serializer=doc_serializer,
-    )
+    chunks = chunker.chunk(dl_doc=dl_doc)
     act_data = dict(
         root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
     )

test/test_hybrid_chunker.py

Lines changed: 11 additions & 8 deletions

@@ -10,10 +10,12 @@
 from docling_core.experimental.serializer.markdown import MarkdownTableSerializer
 from docling_core.transforms.chunker.hierarchical_chunker import (
     ChunkingDocSerializer,
+    ChunkingSerializerProvider,
     DocChunk,
 )
 from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
 from docling_core.types.doc import DoclingDocument as DLDocument
+from docling_core.types.doc.document import DoclingDocument
 
 from .test_data_gen_flag import GEN_TEST_DATA
 
@@ -193,20 +195,21 @@ def test_chunk_custom_serializer():
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)
 
+    class MySerializerProvider(ChunkingSerializerProvider):
+        def get_serializer(self, doc: DoclingDocument):
+            return ChunkingDocSerializer(
+                doc=doc,
+                table_serializer=MarkdownTableSerializer(),
+            )
+
     chunker = HybridChunker(
         tokenizer=TOKENIZER,
         max_tokens=MAX_TOKENS,
         merge_peers=True,
-    )
-    doc_serializer = ChunkingDocSerializer(
-        doc=dl_doc,
-        table_serializer=MarkdownTableSerializer(), # configuring a different table serializer
+        serializer_provider=MySerializerProvider(),
     )
 
-    chunk_iter = chunker.chunk(
-        dl_doc=dl_doc,
-        doc_serializer=doc_serializer,
-    )
+    chunk_iter = chunker.chunk(dl_doc=dl_doc)
     chunks = list(chunk_iter)
     act_data = dict(
         root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
     )
