
Commit 763e136

ruizguille and vagenas authored
feat: Add tiktoken tokenizers support to HybridChunker (#240)
* feat: Add tiktoken tokenizers support to HybridChunker

  Signed-off-by: ruizguille <[email protected]>

* separate OpenAI tokenizer

  Signed-off-by: Panos Vagenas <[email protected]>

---------

Signed-off-by: ruizguille <[email protected]>
Signed-off-by: Panos Vagenas <[email protected]>
Co-authored-by: Panos Vagenas <[email protected]>
1 parent c19c516 commit 763e136

10 files changed: +831 additions, -64 deletions

docling_core/transforms/chunker/hybrid_chunker.py

Lines changed: 49 additions & 31 deletions
@@ -8,27 +8,21 @@
 from functools import cached_property
 from typing import Any, Iterable, Iterator, Optional, Union
 
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    PositiveInt,
-    TypeAdapter,
-    computed_field,
-    model_validator,
-)
-from typing_extensions import Self
+from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator
 
 from docling_core.transforms.chunker.hierarchical_chunker import (
     ChunkingSerializerProvider,
 )
+from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer
 
 try:
     import semchunk
-    from transformers import AutoTokenizer, PreTrainedTokenizerBase
 except ImportError:
     raise RuntimeError(
-        "Module requires 'chunking' extra; to install, run: "
-        "`pip install 'docling-core[chunking]'`"
+        "Extra required by module: 'chunking' by default (or 'chunking-openai' if "
+        "specifically using OpenAI tokenization); to install, run: "
+        "`pip install 'docling-core[chunking]'` or "
+        "`pip install 'docling-core[chunking-openai]'`"
     )
 
 from docling_core.experimental.serializer.base import (
@@ -45,6 +39,16 @@
 from docling_core.types import DoclingDocument
 
 
+def _get_default_tokenizer():
+    from docling_core.transforms.chunker.tokenizer.huggingface import (
+        HuggingFaceTokenizer,
+    )
+
+    return HuggingFaceTokenizer.from_pretrained(
+        model_name="sentence-transformers/all-MiniLM-L6-v2"
+    )
+
+
 class HybridChunker(BaseChunker):
     r"""Chunker doing tokenization-aware refinements on top of document layout chunking.
 
@@ -58,26 +62,40 @@ class HybridChunker(BaseChunker):
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    tokenizer: Union[PreTrainedTokenizerBase, str] = (
-        "sentence-transformers/all-MiniLM-L6-v2"
-    )
-    max_tokens: int = None  # type: ignore[assignment]
+    tokenizer: BaseTokenizer = Field(default_factory=_get_default_tokenizer)
     merge_peers: bool = True
 
     serializer_provider: BaseSerializerProvider = ChunkingSerializerProvider()
 
-    @model_validator(mode="after")
-    def _patch_tokenizer_and_max_tokens(self) -> Self:
-        self._tokenizer = (
-            self.tokenizer
-            if isinstance(self.tokenizer, PreTrainedTokenizerBase)
-            else AutoTokenizer.from_pretrained(self.tokenizer)
-        )
-        if self.max_tokens is None:
-            self.max_tokens = TypeAdapter(PositiveInt).validate_python(
-                self._tokenizer.model_max_length
-            )
-        return self
+    @model_validator(mode="before")
+    @classmethod
+    def _patch(cls, data: Any) -> Any:
+        if isinstance(data, dict) and (tokenizer := data.get("tokenizer")):
+            max_tokens = data.get("max_tokens")
+            if isinstance(tokenizer, BaseTokenizer):
+                pass
+            else:
+                from docling_core.transforms.chunker.tokenizer.huggingface import (
+                    HuggingFaceTokenizer,
+                )
+
+                if isinstance(tokenizer, str):
+                    data["tokenizer"] = HuggingFaceTokenizer.from_pretrained(
+                        model_name=tokenizer,
+                        max_tokens=max_tokens,
+                    )
+                else:
+                    # migrate previous HF-based tokenizers
+                    kwargs = {"tokenizer": tokenizer}
+                    if max_tokens is not None:
+                        kwargs["max_tokens"] = max_tokens
+                    data["tokenizer"] = HuggingFaceTokenizer(**kwargs)
+        return data
+
+    @property
+    def max_tokens(self) -> int:
+        """Get maximum number of tokens allowed."""
+        return self.tokenizer.get_max_tokens()
 
     @computed_field  # type: ignore[misc]
     @cached_property
@@ -92,7 +110,7 @@ def _count_text_tokens(self, text: Optional[Union[str, list[str]]]):
             for t in text:
                 total += self._count_text_tokens(t)
             return total
-        return len(self._tokenizer.tokenize(text))
+        return self.tokenizer.count_tokens(text=text)
 
     class _ChunkLengthInfo(BaseModel):
         total_len: int
@@ -101,7 +119,7 @@ class _ChunkLengthInfo(BaseModel):
 
     def _count_chunk_tokens(self, doc_chunk: DocChunk):
         ser_txt = self.contextualize(chunk=doc_chunk)
-        return len(self._tokenizer.tokenize(text=ser_txt))
+        return self.tokenizer.count_tokens(text=ser_txt)
 
     def _doc_chunk_length(self, doc_chunk: DocChunk):
         text_length = self._count_text_tokens(doc_chunk.text)
@@ -198,7 +216,7 @@ def _split_using_plain_text(
         # captions:
         available_length = self.max_tokens - lengths.other_len
         sem_chunker = semchunk.chunkerify(
-            self._tokenizer, chunk_size=available_length
+            self.tokenizer.get_tokenizer(), chunk_size=available_length
        )
         if available_length <= 0:
             warnings.warn(
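
With this change, HybridChunker.tokenizer takes a BaseTokenizer instance; plain model-name strings and raw HuggingFace tokenizers passed by existing callers are migrated by the mode="before" validator shown above. A minimal usage sketch (not part of the diff), assuming the 'chunking' extra is installed and `doc` is an existing DoclingDocument; the model name and token cap are illustrative:

from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer

# Illustrative model name and cap; max_tokens is optional and otherwise falls
# back to the model's model_max_length.
tokenizer = HuggingFaceTokenizer.from_pretrained(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    max_tokens=64,
)
chunker = HybridChunker(tokenizer=tokenizer, merge_peers=True)

for chunk in chunker.chunk(dl_doc=doc):  # doc: an existing DoclingDocument
    print(chunker.contextualize(chunk=chunk))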
Lines changed: 1 addition & 0 deletions (new file)

"""Define the tokenizer types."""
docling_core/transforms/chunker/tokenizer/base.py

Lines changed: 25 additions & 0 deletions (new file)

"""Define base classes for tokenization."""

from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel


class BaseTokenizer(BaseModel, ABC):
    """Base tokenizer class."""

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Get number of tokens for given text."""
        ...

    @abstractmethod
    def get_max_tokens(self) -> int:
        """Get maximum number of tokens allowed."""
        ...

    @abstractmethod
    def get_tokenizer(self) -> Any:
        """Get underlying tokenizer object."""
        ...
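
Any tokenizer can be plugged into the chunker by implementing these three methods. A hypothetical sketch (not part of this commit) using a toy whitespace-based count:

from typing import Any

from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer


class WhitespaceTokenizer(BaseTokenizer):
    """Toy tokenizer counting whitespace-separated words (illustration only)."""

    max_tokens: int = 512

    def count_tokens(self, text: str) -> int:
        return len(text.split())

    def get_max_tokens(self) -> int:
        return self.max_tokens

    def get_tokenizer(self) -> Any:
        # HybridChunker passes this object to semchunk.chunkerify, which also
        # accepts a plain token-counting callable.
        return self.count_tokens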
docling_core/transforms/chunker/tokenizer/huggingface.py

Lines changed: 70 additions & 0 deletions (new file)

"""HuggingFace tokenization."""

import sys
from os import PathLike
from typing import Optional, Union

from pydantic import ConfigDict, PositiveInt, TypeAdapter, model_validator
from typing_extensions import Self

from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer

try:
    from transformers import AutoTokenizer, PreTrainedTokenizerBase
except ImportError:
    raise RuntimeError(
        "Module requires 'chunking' extra; to install, run: "
        "`pip install 'docling-core[chunking]'`"
    )


class HuggingFaceTokenizer(BaseTokenizer):
    """HuggingFace tokenizer."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    tokenizer: PreTrainedTokenizerBase
    max_tokens: int = None  # type: ignore[assignment]

    @model_validator(mode="after")
    def _patch(self) -> Self:
        if hasattr(self.tokenizer, "model_max_length"):
            model_max_tokens: PositiveInt = TypeAdapter(PositiveInt).validate_python(
                self.tokenizer.model_max_length
            )
            user_max_tokens = self.max_tokens or sys.maxsize
            self.max_tokens = min(model_max_tokens, user_max_tokens)
        elif self.max_tokens is None:
            raise ValueError(
                "max_tokens must be defined as model does not define model_max_length"
            )
        return self

    def count_tokens(self, text: str):
        """Get number of tokens for given text."""
        return len(self.tokenizer.tokenize(text=text))

    def get_max_tokens(self):
        """Get maximum number of tokens allowed."""
        return self.max_tokens

    @classmethod
    def from_pretrained(
        cls,
        model_name: Union[str, PathLike],
        max_tokens: Optional[int] = None,
        **kwargs,
    ) -> Self:
        """Create tokenizer from model name."""
        my_kwargs = {
            "tokenizer": AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=model_name, **kwargs
            ),
        }
        if max_tokens is not None:
            my_kwargs["max_tokens"] = max_tokens
        return cls(**my_kwargs)

    def get_tokenizer(self):
        """Get underlying tokenizer object."""
        return self.tokenizer
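
Two ways of constructing the tokenizer above, sketched with an illustrative model name (requires the 'chunking' extra):

from transformers import AutoTokenizer

from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer

# 1) Load by model name via the convenience constructor:
tok = HuggingFaceTokenizer.from_pretrained(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# 2) Wrap an already-instantiated transformers tokenizer and cap the length:
tok = HuggingFaceTokenizer(
    tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
    max_tokens=64,
)

print(tok.count_tokens(text="Docling converts documents."))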
Lines changed: 34 additions & 0 deletions (new file)

"""OpenAI tokenization."""

from pydantic import ConfigDict

from docling_core.transforms.chunker.hybrid_chunker import BaseTokenizer

try:
    import tiktoken
except ImportError:
    raise RuntimeError(
        "Module requires 'chunking-openai' extra; to install, run: "
        "`pip install 'docling-core[chunking-openai]'`"
    )


class OpenAITokenizer(BaseTokenizer):
    """OpenAI tokenizer."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    tokenizer: tiktoken.Encoding
    max_tokens: int

    def count_tokens(self, text: str):
        """Get number of tokens for given text."""
        return len(self.tokenizer.encode(text=text))

    def get_max_tokens(self):
        """Get maximum number of tokens allowed."""
        return self.max_tokens

    def get_tokenizer(self):
        """Get underlying tokenizer object."""
        return self.tokenizer
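
A minimal sketch of pairing the OpenAI tokenizer with HybridChunker (requires the 'chunking-openai' extra). The encoding, model name, and token limit are illustrative, and the import path for OpenAITokenizer is assumed from the package layout introduced in this commit:

import tiktoken

from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.tokenizer.openai import OpenAITokenizer  # assumed module path

tokenizer = OpenAITokenizer(
    tokenizer=tiktoken.encoding_for_model("gpt-4o"),  # illustrative model
    max_tokens=128 * 1024,  # tiktoken encodings carry no limit, so set one explicitly
)
chunker = HybridChunker(tokenizer=tokenizer)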
