
Commit ebc356a

fix: fix hybrid chunker legacy patching (#300)
Signed-off-by: Panos Vagenas <[email protected]>
1 parent 6274b1a commit ebc356a

2 files changed: +117 −34 lines
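
For orientation, here is a minimal sketch of the updated (non-deprecated) initialization style that this commit steers users toward, based on the patterns used in the updated tests below; the model name, token budget, and document path are placeholders:

from transformers import AutoTokenizer

from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from docling_core.types.doc import DoclingDocument

# Wrap the raw HF tokenizer in the new tokenizer type instead of passing it directly.
tokenizer = HuggingFaceTokenizer(
    tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
    max_tokens=64,
)
chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)

# Placeholder input document, loaded the same way as in the tests.
with open("path/to/doc.json", encoding="utf-8") as f:
    dl_doc = DoclingDocument.model_validate_json(f.read())

for chunk in chunker.chunk(dl_doc=dl_doc):
    print(chunker.contextualize(chunk))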

docling_core/transforms/chunker/hybrid_chunker.py

Lines changed: 22 additions & 7 deletions
@@ -9,6 +9,7 @@
 from typing import Any, Iterable, Iterator, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator
+from transformers import PreTrainedTokenizerBase
 
 from docling_core.transforms.chunker.hierarchical_chunker import (
     ChunkingSerializerProvider,
@@ -70,23 +71,37 @@ class HybridChunker(BaseChunker):
     @model_validator(mode="before")
     @classmethod
     def _patch(cls, data: Any) -> Any:
-        if isinstance(data, dict) and (tokenizer := data.get("tokenizer")):
+        if isinstance(data, dict):
+            tokenizer = data.get("tokenizer")
             max_tokens = data.get("max_tokens")
-            if isinstance(tokenizer, BaseTokenizer):
-                pass
-            else:
+            if not isinstance(tokenizer, BaseTokenizer) and (
+                # some legacy param passed:
+                tokenizer is not None
+                or max_tokens is not None
+            ):
                 from docling_core.transforms.chunker.tokenizer.huggingface import (
                     HuggingFaceTokenizer,
                 )
 
+                warnings.warn(
+                    "Deprecated initialization parameter types for HybridChunker. "
+                    "For updated usage check out "
+                    "https://docling-project.github.io/docling/examples/hybrid_chunking/",
+                    DeprecationWarning,
+                    stacklevel=3,
+                )
+
                 if isinstance(tokenizer, str):
                     data["tokenizer"] = HuggingFaceTokenizer.from_pretrained(
                         model_name=tokenizer,
                         max_tokens=max_tokens,
                     )
-                else:
-                    # migrate previous HF-based tokenizers
-                    kwargs = {"tokenizer": tokenizer}
+                elif tokenizer is None or isinstance(
+                    tokenizer, PreTrainedTokenizerBase
+                ):
+                    kwargs = {
+                        "tokenizer": tokenizer or _get_default_tokenizer().tokenizer
+                    }
                     if max_tokens is not None:
                         kwargs["max_tokens"] = max_tokens
                     data["tokenizer"] = HuggingFaceTokenizer(**kwargs)

test/test_hybrid_chunker.py

Lines changed: 95 additions & 27 deletions
@@ -4,6 +4,7 @@
 #
 
 import json
+import warnings
 
 import tiktoken
 from transformers import AutoTokenizer
@@ -14,6 +15,7 @@
     DocChunk,
 )
 from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
+from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
 from docling_core.transforms.chunker.tokenizer.openai import OpenAITokenizer
 from docling_core.transforms.serializer.markdown import MarkdownTableSerializer
 from docling_core.types.doc import DoclingDocument as DLDocument
@@ -25,7 +27,7 @@
 MAX_TOKENS = 64
 INPUT_FILE = "test/data/chunker/2_inp_dl_doc.json"
 
-TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)
+INNER_TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)
 
 
 def _process(act_data, exp_path_str):
@@ -47,8 +49,10 @@ def test_chunk_merge_peers():
     dl_doc = DLDocument.model_validate_json(data_json)
 
     chunker = HybridChunker(
-        tokenizer=TOKENIZER,
-        max_tokens=MAX_TOKENS,
+        tokenizer=HuggingFaceTokenizer(
+            tokenizer=INNER_TOKENIZER,
+            max_tokens=MAX_TOKENS,
+        ),
         merge_peers=True,
     )
 
@@ -63,20 +67,48 @@ def test_chunk_merge_peers():
     )
 
 
-def test_chunk_no_merge_peers():
-    EXPECTED_OUT_FILE = "test/data/chunker/2b_out_chunks.json"
+def test_chunk_with_model_name():
+    EXPECTED_OUT_FILE = "test/data/chunker/2a_out_chunks.json"
 
     with open(INPUT_FILE, encoding="utf-8") as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)
 
     chunker = HybridChunker(
-        tokenizer=TOKENIZER,
-        max_tokens=MAX_TOKENS,
-        merge_peers=False,
+        tokenizer=HuggingFaceTokenizer.from_pretrained(
+            model_name=EMBED_MODEL_ID,
+            max_tokens=MAX_TOKENS,
+        ),
+        merge_peers=True,
     )
 
-    chunks = chunker.chunk(dl_doc=dl_doc)
+    chunk_iter = chunker.chunk(dl_doc=dl_doc)
+    chunks = list(chunk_iter)
+    act_data = dict(
+        root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
+    )
+    _process(
+        act_data=act_data,
+        exp_path_str=EXPECTED_OUT_FILE,
+    )
+
+
+def test_chunk_deprecated_max_tokens():
+    EXPECTED_OUT_FILE = "test/data/chunker/2a_out_chunks.json"
+
+    with open(INPUT_FILE, encoding="utf-8") as f:
+        data_json = f.read()
+    dl_doc = DLDocument.model_validate_json(data_json)
+
+    with warnings.catch_warnings(record=True) as w:
+        chunker = HybridChunker(
+            max_tokens=MAX_TOKENS,
+        )
+    assert len(w) == 1, "One deprecation warning was expected"
+    assert issubclass(w[-1].category, DeprecationWarning)
+
+    chunk_iter = chunker.chunk(dl_doc=dl_doc)
+    chunks = list(chunk_iter)
     act_data = dict(
         root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
     )
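
The new deprecation tests assert the warning via warnings.catch_warnings; an equivalent check could also be written with pytest.warns, sketched here purely as an alternative (not what the suite uses):

import pytest

from docling_core.transforms.chunker.hybrid_chunker import HybridChunker


def test_deprecated_max_tokens_warns():
    # Passing only the legacy max_tokens should trigger the deprecation path.
    with pytest.warns(DeprecationWarning):
        HybridChunker(max_tokens=64)
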
@@ -94,8 +126,10 @@ def test_contextualize():
     dl_doc = DLDocument.model_validate_json(data_json)
 
     chunker = HybridChunker(
-        tokenizer=TOKENIZER,
-        max_tokens=MAX_TOKENS,
+        tokenizer=HuggingFaceTokenizer(
+            tokenizer=INNER_TOKENIZER,
+            max_tokens=MAX_TOKENS,
+        ),
         merge_peers=True,
     )
 
@@ -106,7 +140,7 @@ def test_contextualize():
             dict(
                 text=chunk.text,
                 ser_text=(ser_text := chunker.contextualize(chunk)),
-                num_tokens=len(TOKENIZER.tokenize(ser_text)),
+                num_tokens=len(INNER_TOKENIZER.tokenize(ser_text)),
             )
             for chunk in chunks
         ]
@@ -117,21 +151,22 @@ def test_contextualize():
     )
 
 
-def test_chunk_with_model_name():
-    EXPECTED_OUT_FILE = "test/data/chunker/2a_out_chunks.json"
+def test_chunk_no_merge_peers():
+    EXPECTED_OUT_FILE = "test/data/chunker/2b_out_chunks.json"
 
     with open(INPUT_FILE, encoding="utf-8") as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)
 
     chunker = HybridChunker(
-        tokenizer=EMBED_MODEL_ID,
-        max_tokens=MAX_TOKENS,
-        merge_peers=True,
+        tokenizer=HuggingFaceTokenizer(
+            tokenizer=INNER_TOKENIZER,
+            max_tokens=MAX_TOKENS,
+        ),
+        merge_peers=False,
     )
 
-    chunk_iter = chunker.chunk(dl_doc=dl_doc)
-    chunks = list(chunk_iter)
+    chunks = chunker.chunk(dl_doc=dl_doc)
     act_data = dict(
         root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
     )
@@ -161,17 +196,43 @@ def test_chunk_default():
     )
 
 
-def test_chunk_excplicit_hf_obj():
+def test_chunk_deprecated_explicit_hf_obj():
     EXPECTED_OUT_FILE = "test/data/chunker/2c_out_chunks.json"
 
     with open(INPUT_FILE, encoding="utf-8") as f:
         data_json = f.read()
     dl_doc = DLDocument.model_validate_json(data_json)
 
-    chunker = HybridChunker(
-        tokenizer=AutoTokenizer.from_pretrained(
-            "sentence-transformers/all-MiniLM-L6-v2"
+    with warnings.catch_warnings(record=True) as w:
+        chunker = HybridChunker(
+            tokenizer=INNER_TOKENIZER,
         )
+    assert len(w) == 1, "One deprecation warning was expected"
+    assert issubclass(w[-1].category, DeprecationWarning)
+
+    chunk_iter = chunker.chunk(dl_doc=dl_doc)
+    chunks = list(chunk_iter)
+    act_data = dict(
+        root=[DocChunk.model_validate(n).export_json_dict() for n in chunks]
+    )
+    _process(
+        act_data=act_data,
+        exp_path_str=EXPECTED_OUT_FILE,
+    )
+
+
+def test_ignore_deprecated_param_if_new_tokenizer_passed():
+    EXPECTED_OUT_FILE = "test/data/chunker/2c_out_chunks.json"
+
+    with open(INPUT_FILE, encoding="utf-8") as f:
+        data_json = f.read()
+    dl_doc = DLDocument.model_validate_json(data_json)
+
+    chunker = HybridChunker(
+        tokenizer=HuggingFaceTokenizer(
+            tokenizer=INNER_TOKENIZER,
+        ),
+        max_tokens=MAX_TOKENS,
     )
 
     chunk_iter = chunker.chunk(dl_doc=dl_doc)
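
test_ignore_deprecated_param_if_new_tokenizer_passed pins down the precedence rule: when a new-style tokenizer object is supplied, a leftover legacy max_tokens is ignored rather than migrated, and per the validator above no deprecation warning is raised. A small sketch of that expectation, with the model name assumed for illustration:

from transformers import AutoTokenizer

from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer

hf_tokenizer = HuggingFaceTokenizer(
    tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
)
# The legacy max_tokens argument is ignored in favor of the explicit tokenizer object.
chunker = HybridChunker(tokenizer=hf_tokenizer, max_tokens=64)
assert isinstance(chunker.tokenizer, HuggingFaceTokenizer)
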
@@ -193,7 +254,12 @@ def test_contextualize_altered_delim():
     dl_doc = DLDocument.model_validate_json(data_json)
 
     chunker = HybridChunker(
-        tokenizer=TOKENIZER, max_tokens=MAX_TOKENS, merge_peers=True, delim="####"
+        tokenizer=HuggingFaceTokenizer(
+            tokenizer=INNER_TOKENIZER,
+            max_tokens=MAX_TOKENS,
+        ),
+        merge_peers=True,
+        delim="####",
     )
 
     chunks = chunker.chunk(dl_doc=dl_doc)
@@ -203,7 +269,7 @@ def test_contextualize_altered_delim():
             dict(
                 text=chunk.text,
                 ser_text=(ser_text := chunker.contextualize(chunk)),
-                num_tokens=len(TOKENIZER.tokenize(ser_text)),
+                num_tokens=len(INNER_TOKENIZER.tokenize(ser_text)),
             )
             for chunk in chunks
         ]
@@ -229,8 +295,10 @@ def get_serializer(self, doc: DoclingDocument):
     )
 
     chunker = HybridChunker(
-        tokenizer=TOKENIZER,
-        max_tokens=MAX_TOKENS,
+        tokenizer=HuggingFaceTokenizer(
+            tokenizer=INNER_TOKENIZER,
+            max_tokens=MAX_TOKENS,
+        ),
         merge_peers=True,
         serializer_provider=MySerializerProvider(),
     )
