
Commit 628ab67

feat: add hybrid chunker (#68)
Signed-off-by: Panos Vagenas <[email protected]>
Co-authored-by: Bill Murdock <[email protected]>
Co-authored-by: Ben Rood <[email protected]>
1 parent 4dd1c87 commit 628ab67

File tree

11 files changed: +4492, -3 lines changed


docling_core/transforms/chunker/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@

 from docling_core.transforms.chunker.base import BaseChunk, BaseChunker, BaseMeta
 from docling_core.transforms.chunker.hierarchical_chunker import (
+    DocChunk,
     DocMeta,
     HierarchicalChunker,
 )

docling_core/transforms/chunker/base.py

Lines changed: 34 additions & 0 deletions
@@ -4,13 +4,16 @@
 #

 """Define base classes for chunking."""
+import json
 from abc import ABC, abstractmethod
 from typing import Any, ClassVar, Iterator

 from pydantic import BaseModel

 from docling_core.types.doc import DoclingDocument as DLDocument

+DFLT_DELIM = "\n"
+

 class BaseMeta(BaseModel):
     """Chunk metadata base class."""

@@ -45,6 +48,8 @@ def export_json_dict(self) -> dict[str, Any]:
 class BaseChunker(BaseModel, ABC):
     """Chunker base class."""

+    delim: str = DFLT_DELIM
+
     @abstractmethod
     def chunk(self, dl_doc: DLDocument, **kwargs) -> Iterator[BaseChunk]:
         """Chunk the provided document.

@@ -59,3 +64,32 @@ def chunk(self, dl_doc: DLDocument, **kwargs) -> Iterator[BaseChunk]:
             Iterator[BaseChunk]: iterator over extracted chunks
         """
         raise NotImplementedError()
+
+    def serialize(self, chunk: BaseChunk) -> str:
+        """Serialize the given chunk. This base implementation is embedding-targeted.
+
+        Args:
+            chunk: chunk to serialize
+
+        Returns:
+            str: the serialized form of the chunk
+        """
+        meta = chunk.meta.export_json_dict()
+
+        items = []
+        for k in meta:
+            if k not in chunk.meta.excluded_embed:
+                if isinstance(meta[k], list):
+                    items.append(
+                        self.delim.join(
+                            [
+                                d if isinstance(d, str) else json.dumps(d)
+                                for d in meta[k]
+                            ]
+                        )
+                    )
+                else:
+                    items.append(json.dumps(meta[k]))
+        items.append(chunk.text)
+
+        return self.delim.join(items)
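
Note: concrete chunkers inherit this embedding-targeted serialization from BaseChunker. A minimal sketch of the intended behavior (assuming BaseMeta can be instantiated without fields; with empty metadata the output reduces to the chunk text, while list-valued metadata such as headings would be prepended and joined with `delim`):

    from docling_core.transforms.chunker import BaseChunk, BaseMeta, HierarchicalChunker

    # No embed-relevant metadata on this chunk, so only the text survives:
    chunk = BaseChunk(text="Some paragraph.", meta=BaseMeta())
    assert HierarchicalChunker().serialize(chunk=chunk) == "Some paragraph."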

docling_core/transforms/chunker/hierarchical_chunker.py

Lines changed: 1 addition & 2 deletions
@@ -104,7 +104,7 @@ def check_version_is_compatible(cls, v: str) -> str:


 class DocChunk(BaseChunk):
-    """Data model for Hierarchical Chunker chunks."""
+    """Data model for document chunks."""

     meta: DocMeta

@@ -119,7 +119,6 @@ class HierarchicalChunker(BaseChunker):
     """

     merge_list_items: bool = True
-    delim: str = "\n"

     @classmethod
     def _triplet_serialize(cls, table_df: DataFrame) -> str:
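
Note: the `delim` field removed here now lives on BaseChunker (see base.py above), so the delimiter is shared by all chunkers, including the new HybridChunker below.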
docling_core/transforms/chunker/hybrid_chunker.py

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
+#
+# Copyright IBM Corp. 2024 - 2024
+# SPDX-License-Identifier: MIT
+#
+
+"""Hybrid chunker implementation leveraging both doc structure & token awareness."""
+
+import warnings
+from typing import Iterable, Iterator, Optional, Union
+
+from pydantic import BaseModel, ConfigDict, PositiveInt, TypeAdapter, model_validator
+from typing_extensions import Self
+
+try:
+    import semchunk
+    from transformers import AutoTokenizer, PreTrainedTokenizerBase
+except ImportError:
+    raise RuntimeError(
+        "Module requires 'chunking' extra; to install, run: "
+        "`pip install 'docling-core[chunking]'`"
+    )
+
+from docling_core.transforms.chunker import (
+    BaseChunk,
+    BaseChunker,
+    DocChunk,
+    DocMeta,
+    HierarchicalChunker,
+)
+from docling_core.types import DoclingDocument
+from docling_core.types.doc.document import TextItem
+
+
+class HybridChunker(BaseChunker):
+    r"""Chunker doing tokenization-aware refinements on top of document layout chunking.
+
+    Args:
+        tokenizer: The tokenizer to use; either instantiated object or name or path of
+            respective pretrained model
+        max_tokens: The maximum number of tokens per chunk. If not set, limit is
+            resolved from the tokenizer
+        merge_peers: Whether to merge undersized chunks sharing same relevant metadata
+    """
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    tokenizer: Union[PreTrainedTokenizerBase, str]
+    max_tokens: int = None  # type: ignore[assignment]
+    merge_peers: bool = True
+
+    _inner_chunker: HierarchicalChunker = HierarchicalChunker()
+
+    @model_validator(mode="after")
+    def _patch_tokenizer_and_max_tokens(self) -> Self:
+        self._tokenizer = (
+            self.tokenizer
+            if isinstance(self.tokenizer, PreTrainedTokenizerBase)
+            else AutoTokenizer.from_pretrained(self.tokenizer)
+        )
+        if self.max_tokens is None:
+            self.max_tokens = TypeAdapter(PositiveInt).validate_python(
+                self._tokenizer.model_max_length
+            )
+        return self
+
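The validator above accepts either an instantiated transformers tokenizer or a model name/path, and, when max_tokens is unset, adopts the tokenizer's model_max_length. A minimal instantiation sketch (the model id is illustrative, not prescribed by this commit):

    from docling_core.transforms.chunker.hybrid_chunker import HybridChunker

    chunker = HybridChunker(
        tokenizer="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
        max_tokens=64,  # optional; defaults to the tokenizer's model_max_length
    )
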
+    def _count_tokens(self, text: Optional[Union[str, list[str]]]):
+        if text is None:
+            return 0
+        elif isinstance(text, list):
+            total = 0
+            for t in text:
+                total += self._count_tokens(t)
+            return total
+        return len(self._tokenizer.tokenize(text, max_length=None))
+
+    class _ChunkLengthInfo(BaseModel):
+        total_len: int
+        text_len: int
+        other_len: int
+
+    def _doc_chunk_length(self, doc_chunk: DocChunk):
+        text_length = self._count_tokens(doc_chunk.text)
+        headings_length = self._count_tokens(doc_chunk.meta.headings)
+        captions_length = self._count_tokens(doc_chunk.meta.captions)
+        total = text_length + headings_length + captions_length
+        return self._ChunkLengthInfo(
+            total_len=total,
+            text_len=text_length,
+            other_len=total - text_length,
+        )
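
Worked example of the bookkeeping above: with max_tokens=64 and headings plus captions amounting to 10 tokens, other_len is 10, leaving a 54-token budget for the chunk text; both splitting passes below work against this budget.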
+
+    def _make_chunk_from_doc_items(
+        self, doc_chunk: DocChunk, window_text: str, window_start: int, window_end: int
+    ):
+        meta = DocMeta(
+            doc_items=doc_chunk.meta.doc_items[window_start : window_end + 1],
+            headings=doc_chunk.meta.headings,
+            captions=doc_chunk.meta.captions,
+        )
+        new_chunk = DocChunk(text=window_text, meta=meta)
+        return new_chunk
+
+    def _merge_text(self, t1, t2):
+        if t1 == "":
+            return t2
+        elif t2 == "":
+            return t1
+        else:
+            return f"{t1}{self.delim}{t2}"
+
+    def _split_by_doc_items(self, doc_chunk: DocChunk) -> list[DocChunk]:
+        if doc_chunk.meta.doc_items is None or len(doc_chunk.meta.doc_items) <= 1:
+            return [doc_chunk]
+        length = self._doc_chunk_length(doc_chunk)
+        if length.total_len <= self.max_tokens:
+            return [doc_chunk]
+        else:
+            chunks = []
+            window_start = 0
+            window_end = 0
+            window_text = ""
+            window_text_length = 0
+            other_length = length.other_len
+            num_items = len(doc_chunk.meta.doc_items)
+            while window_end < num_items:
+                doc_item = doc_chunk.meta.doc_items[window_end]
+                if isinstance(doc_item, TextItem):
+                    text = doc_item.text
+                else:
+                    raise RuntimeError("Non-TextItem split not implemented yet")
+                text_length = self._count_tokens(text)
+                if (
+                    text_length + window_text_length + other_length < self.max_tokens
+                    and window_end < num_items - 1
+                ):
+                    # Still room left to add more to this chunk AND still at least one
+                    # item left
+                    window_end += 1
+                    window_text_length += text_length
+                    window_text = self._merge_text(window_text, text)
+                elif text_length + window_text_length + other_length < self.max_tokens:
+                    # All the items in the window fit into the chunk and there are no
+                    # other items left
+                    window_text = self._merge_text(window_text, text)
+                    new_chunk = self._make_chunk_from_doc_items(
+                        doc_chunk, window_text, window_start, window_end
+                    )
+                    chunks.append(new_chunk)
+                    window_end = num_items
+                elif window_start == window_end:
+                    # Only one item in the window and it doesn't fit into the chunk. So
+                    # we'll just make it a chunk for now and it will get split in the
+                    # plain text splitter.
+                    window_text = self._merge_text(window_text, text)
+                    new_chunk = self._make_chunk_from_doc_items(
+                        doc_chunk, window_text, window_start, window_end
+                    )
+                    chunks.append(new_chunk)
+                    window_start = window_end + 1
+                    window_end = window_start
+                    window_text = ""
+                    window_text_length = 0
+                else:
+                    # Multiple items in the window but they don't fit into the chunk.
+                    # However, the existing items must have fit or we wouldn't have
+                    # gotten here. So we put everything but the last item into the chunk
+                    # and then start a new window INCLUDING the current window end.
+                    new_chunk = self._make_chunk_from_doc_items(
+                        doc_chunk, window_text, window_start, window_end - 1
+                    )
+                    chunks.append(new_chunk)
+                    window_start = window_end
+                    window_text = ""
+                    window_text_length = 0
+            return chunks
+
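In short, doc items are packed greedily: the window keeps growing while the running token count (including the heading/caption overhead) stays under max_tokens; a window that can no longer grow is flushed as a chunk, and a single item that is oversized on its own is emitted as-is, to be cut down by the plain-text splitter below.
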
+    def _split_using_plain_text(
+        self,
+        doc_chunk: DocChunk,
+    ) -> list[DocChunk]:
+        lengths = self._doc_chunk_length(doc_chunk)
+        if lengths.total_len <= self.max_tokens:
+            return [DocChunk(**doc_chunk.export_json_dict())]
+        else:
+            # How much room is there for text after subtracting out the headers and
+            # captions:
+            available_length = self.max_tokens - lengths.other_len
+            sem_chunker = semchunk.chunkerify(
+                self._tokenizer, chunk_size=available_length
+            )
+            if available_length <= 0:
+                warnings.warn(
+                    f"Headers and captions for this chunk are longer than the total amount of size for the chunk, chunk will be ignored: {doc_chunk.text=}"  # noqa
+                )
+                return []
+            text = doc_chunk.text
+            segments = sem_chunker.chunk(text)
+            chunks = [DocChunk(text=s, meta=doc_chunk.meta) for s in segments]
+            return chunks
+
+    def _merge_chunks_with_matching_metadata(self, chunks: list[DocChunk]):
+        output_chunks = []
+        window_start = 0
+        window_end = 0
+        num_chunks = len(chunks)
+        while window_end < num_chunks:
+            chunk = chunks[window_end]
+            lengths = self._doc_chunk_length(chunk)
+            headings_and_captions = (chunk.meta.headings, chunk.meta.captions)
+            ready_to_append = False
+            if window_start == window_end:
+                # starting a new block of chunks to potentially merge
+                current_headings_and_captions = headings_and_captions
+                window_text = chunk.text
+                window_other_length = lengths.other_len
+                window_text_length = lengths.text_len
+                window_items = chunk.meta.doc_items
+                window_end += 1
+                first_chunk_of_window = chunk
+            elif (
+                headings_and_captions == current_headings_and_captions
+                and window_text_length + window_other_length + lengths.text_len
+                <= self.max_tokens
+            ):
+                # there is room to include the new chunk so add it to the window and
+                # continue
+                window_text = self._merge_text(window_text, chunk.text)
+                window_text_length += lengths.text_len
+                window_items = window_items + chunk.meta.doc_items
+                window_end += 1
+            else:
+                ready_to_append = True
+
+            if ready_to_append or window_end == num_chunks:
+                # no more room OR the start of new metadata. Either way, end the block
+                # and use the current window_end as the start of a new block
+                if window_start + 1 == window_end:
+                    # just one chunk so use it as is
+                    output_chunks.append(first_chunk_of_window)
+                else:
+                    new_meta = DocMeta(
+                        doc_items=window_items,
+                        headings=current_headings_and_captions[0],
+                        captions=current_headings_and_captions[1],
+                    )
+                    new_chunk = DocChunk(
+                        text=window_text,
+                        meta=new_meta,
+                    )
+                    output_chunks.append(new_chunk)
+                # no need to reset window_text, etc. because that will be reset in the
+                # next iteration in the if window_start == window_end block
+                window_start = window_end
+
+        return output_chunks
+
+    def chunk(self, dl_doc: DoclingDocument, **kwargs) -> Iterator[BaseChunk]:
+        r"""Chunk the provided document.
+
+        Args:
+            dl_doc (DLDocument): document to chunk
+
+        Yields:
+            Iterator[Chunk]: iterator over extracted chunks
+        """
+        res: Iterable[DocChunk]
+        res = self._inner_chunker.chunk(dl_doc=dl_doc, **kwargs)  # type: ignore
+        res = [x for c in res for x in self._split_by_doc_items(c)]
+        res = [x for c in res for x in self._split_using_plain_text(c)]
+        if self.merge_peers:
+            res = self._merge_chunks_with_matching_metadata(res)
+        return iter(res)
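
Taken together, chunk() runs the hierarchical chunker, splits oversized chunks first along doc-item boundaries and then by plain text (via semchunk), and finally merges undersized peers sharing the same headings and captions. An end-to-end usage sketch (the document-loading step and model id are illustrative assumptions, not part of this commit):

    from pathlib import Path

    from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
    from docling_core.types import DoclingDocument

    # Illustrative setup: parse a previously JSON-serialized DoclingDocument.
    doc = DoclingDocument.model_validate_json(Path("doc.json").read_text())

    chunker = HybridChunker(
        tokenizer="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
    )

    for chunk in chunker.chunk(dl_doc=doc):
        # serialize() (inherited from BaseChunker) yields the embedding-targeted
        # text: embed-relevant metadata plus the chunk text, joined by `delim`.
        print(chunker.serialize(chunk=chunk))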
