 from functools import cached_property
 from typing import Any, Iterable, Iterator, Optional, Union

-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    PositiveInt,
-    TypeAdapter,
-    computed_field,
-    model_validator,
-)
-from typing_extensions import Self
+from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator

 from docling_core.transforms.chunker.hierarchical_chunker import (
     ChunkingSerializerProvider,
 )
+from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer

 try:
     import semchunk
-    from transformers import AutoTokenizer, PreTrainedTokenizerBase
 except ImportError:
     raise RuntimeError(
-        "Module requires 'chunking' extra; to install, run: "
-        "`pip install 'docling-core[chunking]'`"
+        "Extra required by module: 'chunking' by default (or 'chunking-openai' if "
+        "specifically using OpenAI tokenization); to install, run: "
+        "`pip install 'docling-core[chunking]'` or "
+        "`pip install 'docling-core[chunking-openai]'`"
     )

 from docling_core.experimental.serializer.base import (
 from docling_core.types import DoclingDocument


+def _get_default_tokenizer():
+    from docling_core.transforms.chunker.tokenizer.huggingface import (
+        HuggingFaceTokenizer,
+    )
+
+    return HuggingFaceTokenizer.from_pretrained(
+        model_name="sentence-transformers/all-MiniLM-L6-v2"
+    )
+
+
 class HybridChunker(BaseChunker):
     r"""Chunker doing tokenization-aware refinements on top of document layout chunking.

@@ -58,26 +62,40 @@ class HybridChunker(BaseChunker):

     model_config = ConfigDict(arbitrary_types_allowed=True)

-    tokenizer: Union[PreTrainedTokenizerBase, str] = (
-        "sentence-transformers/all-MiniLM-L6-v2"
-    )
-    max_tokens: int = None  # type: ignore[assignment]
+    tokenizer: BaseTokenizer = Field(default_factory=_get_default_tokenizer)
     merge_peers: bool = True

     serializer_provider: BaseSerializerProvider = ChunkingSerializerProvider()

-    @model_validator(mode="after")
-    def _patch_tokenizer_and_max_tokens(self) -> Self:
-        self._tokenizer = (
-            self.tokenizer
-            if isinstance(self.tokenizer, PreTrainedTokenizerBase)
-            else AutoTokenizer.from_pretrained(self.tokenizer)
-        )
-        if self.max_tokens is None:
-            self.max_tokens = TypeAdapter(PositiveInt).validate_python(
-                self._tokenizer.model_max_length
-            )
-        return self
+    @model_validator(mode="before")
+    @classmethod
+    def _patch(cls, data: Any) -> Any:
+        if isinstance(data, dict) and (tokenizer := data.get("tokenizer")):
+            max_tokens = data.get("max_tokens")
+            if isinstance(tokenizer, BaseTokenizer):
+                pass
+            else:
+                from docling_core.transforms.chunker.tokenizer.huggingface import (
+                    HuggingFaceTokenizer,
+                )
+
+                if isinstance(tokenizer, str):
+                    data["tokenizer"] = HuggingFaceTokenizer.from_pretrained(
+                        model_name=tokenizer,
+                        max_tokens=max_tokens,
+                    )
+                else:
+                    # migrate previous HF-based tokenizers
+                    kwargs = {"tokenizer": tokenizer}
+                    if max_tokens is not None:
+                        kwargs["max_tokens"] = max_tokens
+                    data["tokenizer"] = HuggingFaceTokenizer(**kwargs)
+        return data
+
+    @property
+    def max_tokens(self) -> int:
+        """Get maximum number of tokens allowed."""
+        return self.tokenizer.get_max_tokens()

     @computed_field  # type: ignore[misc]
     @cached_property
@@ -92,7 +110,7 @@ def _count_text_tokens(self, text: Optional[Union[str, list[str]]]):
             for t in text:
                 total += self._count_text_tokens(t)
             return total
-        return len(self._tokenizer.tokenize(text))
+        return self.tokenizer.count_tokens(text=text)

     class _ChunkLengthInfo(BaseModel):
         total_len: int
@@ -101,7 +119,7 @@ class _ChunkLengthInfo(BaseModel):

     def _count_chunk_tokens(self, doc_chunk: DocChunk):
         ser_txt = self.contextualize(chunk=doc_chunk)
-        return len(self._tokenizer.tokenize(text=ser_txt))
+        return self.tokenizer.count_tokens(text=ser_txt)

     def _doc_chunk_length(self, doc_chunk: DocChunk):
         text_length = self._count_text_tokens(doc_chunk.text)
@@ -198,7 +216,7 @@ def _split_using_plain_text(
         # captions:
         available_length = self.max_tokens - lengths.other_len
         sem_chunker = semchunk.chunkerify(
-            self._tokenizer, chunk_size=available_length
+            self.tokenizer.get_tokenizer(), chunk_size=available_length
         )
         if available_length <= 0:
             warnings.warn(
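
Usage note (not part of the diff): a minimal sketch of how the refactored tokenizer field could be exercised, assuming docling-core is installed with the 'chunking' extra; the import paths and the max_tokens value shown here are illustrative assumptions based on the modules referenced in the diff.

from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer

# New-style API: pass an explicit BaseTokenizer instance.
tokenizer = HuggingFaceTokenizer.from_pretrained(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    max_tokens=64,  # illustrative cap; omit to fall back to the model's own limit
)
chunker = HybridChunker(tokenizer=tokenizer)

# Backward-compatible call: a plain model name (or a raw HF tokenizer) is
# wrapped into a HuggingFaceTokenizer by the "before" model validator.
legacy_style = HybridChunker(tokenizer="sentence-transformers/all-MiniLM-L6-v2")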