 import re
 from abc import ABC
 from collections.abc import Generator
+from dataclasses import dataclass, field
 from typing import Optional

 import tiktoken
@@ -84,6 +85,107 @@ def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
 DEFAULT_SECTION_LENGTH = 1000  # Roughly 400-500 tokens for English


+def _safe_concat(a: str, b: str) -> str:
+    """Concatenate two non-empty segments, inserting a space only when both sides
+    end/start with alphanumerics and no natural boundary exists.
+
+    Rules:
+    - Both input strings are expected to be non-empty.
+    - Preserve existing whitespace if either side already provides a boundary.
+    - Do not insert a space after a closing HTML tag marker '>'.
+    - If both boundary characters are alphanumeric, insert a single space.
+    - Otherwise concatenate directly.
+    """
+    assert a and b, "_safe_concat expects non-empty strings"
+    a_last = a[-1]
+    b_first = b[0]
+    if a_last.isspace() or b_first.isspace():  # pre-existing boundary
+        return a + b
+    if a_last == ">":  # HTML tag end acts as a boundary
+        return a + b
+    if a_last.isalnum() and b_first.isalnum():  # need explicit separator
+        return a + " " + b
+    return a + b
+
+
+def _normalize_chunk(text: str, max_chars: int) -> str:
+    """Normalize a non-figure chunk that may slightly exceed max_chars.
+
+    Allows overflow for any chunk containing a <figure ...> tag (figures are atomic),
+    trims leading spaces if they alone cause minor overflow, and as a final step
+    removes a trailing space/newline when within a small tolerance (<=3 chars over).
+    """
+    lower = text.lower()
+    if "<figure" in lower:
+        return text  # never trim figure chunks
+    if len(text) <= max_chars:
+        return text
+    trimmed = text
+    while trimmed.startswith(" ") and len(trimmed) > max_chars:
+        trimmed = trimmed[1:]
+    if len(trimmed) > max_chars and len(trimmed) <= max_chars + 3:
+        if trimmed.endswith(" ") or trimmed.endswith("\n"):
+            trimmed = trimmed.rstrip()
+    return trimmed
+
+
+@dataclass
+class _ChunkBuilder:
+    """Accumulates sentence-like units for a single page until size limits are reached.
+
+    Responsibilities:
+    - Track appended text fragments and their approximate token length.
+    - Decide if a new unit can be added without exceeding character or token thresholds.
+    - Flush accumulated content into an output list as a `SplitPage`.
+    - Allow a figure block to be force-appended (even if it overflows) so that headings + figure stay together.
+
+    Notes:
+    - The character limit is soft (enforced here, then relaxed by later normalization); the token limit is hard.
+    - Token counts are computed by the caller and passed to `add`; this class stays agnostic of the encoder.
+    """
+
+    page_num: int
+    max_chars: int
+    max_tokens: int
+    parts: list[str] = field(default_factory=list)
+    token_len: int = 0
+
+    def can_fit(self, text: str, token_count: int) -> bool:
+        if not self.parts:  # first unit only needs to fit on its own
+            return token_count <= self.max_tokens and len(text) <= self.max_chars
+        # Character + token constraints
+        return (len("".join(self.parts)) + len(text) <= self.max_chars) and (
+            self.token_len + token_count <= self.max_tokens
+        )
+
+    def add(self, text: str, token_count: int) -> bool:
+        if not self.can_fit(text, token_count):
+            return False
+        self.parts.append(text)
+        self.token_len += token_count
+        return True
+
+    def force_append(self, text: str):
+        self.parts.append(text)
+
+    def flush_into(self, out: list[SplitPage]):
+        if self.parts:
+            chunk = "".join(self.parts)
+            if chunk.strip():
+                out.append(SplitPage(page_num=self.page_num, text=chunk))
+            self.parts.clear()
+            self.token_len = 0
+
+    # Convenience helpers for readability at call sites
+    def has_content(self) -> bool:
+        return bool(self.parts)
+
+    def append_figure_and_flush(self, figure_text: str, out: list[SplitPage]):
+        """Append a figure (allowed to overflow) to the current accumulation and flush in one step."""
+        self.force_append(figure_text)
+        self.flush_into(out)
+
+
 class SentenceTextSplitter(TextSplitter):
     """
     Class that splits pages into smaller chunks. This is required because embedding models may not be able to analyze an entire page at once
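
A quick sanity check of the _safe_concat boundary rules added above; the import path is a guess at the repo layout and the inputs are made up:

from prepdocslib.textsplitter import _safe_concat  # hypothetical import path

assert _safe_concat("Hello", "world") == "Hello world"   # alnum meets alnum: space inserted
assert _safe_concat("Hello ", "world") == "Hello world"  # existing whitespace is a boundary
assert _safe_concat("<p>", "Text") == "<p>Text"          # '>' already marks a boundary
assert _safe_concat("end.", "Next") == "end.Next"        # punctuation: concatenate directly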
@@ -203,23 +305,6 @@ def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
             if not raw.strip():
                 continue

-            def _safe_concat(a: str, b: str) -> str:
-                """Concatenate two non-empty segments, inserting a space only when both sides
-                end/start with alphanumerics and no natural boundary exists. (Empty inputs are
-                never passed here by construction.)"""
-                a_last = a[-1]
-                b_first = b[0]
-                # If either already has whitespace boundary, just concat
-                if a_last.isspace() or b_first.isspace():
-                    return a + b
-                # If a ends with '>' (HTML tag) we assume boundary is intentional.
-                if a_last == ">":
-                    return a + b
-                # If both alnum, insert space.
-                if a_last.isalnum() and b_first.isalnum():
-                    return a + " " + b
-                return a + b
-
             # 1. Build ordered list of blocks: (type, text)
             blocks: list[tuple[str, str]] = []
             last = 0
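
For orientation, the blocks list built here pairs each block's type with its text. A sketch with invented content (the "text" tag for non-figure blocks is assumed; only "figure" is visible in the diff):

# Hypothetical shape of the (type, text) list consumed by the loop in the next hunk.
blocks: list[tuple[str, str]] = [
    ("text", "Intro paragraph before the figure."),
    ("figure", "<figure><figcaption>Figure 1</figcaption></figure>"),
    ("text", "Paragraph that follows the figure."),
]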
@@ -234,25 +319,18 @@ def _safe_concat(a: str, b: str) -> str:
             # Accumulated chunks for this page
             page_chunks: list[SplitPage] = []

-            # Builder state for text accumulation
-            builder: list[str] = []
-            builder_token_len = 0
-
-            def flush_builder_to_page():
-                nonlocal builder, builder_token_len
-                if builder:
-                    text_chunk = "".join(builder)
-                    if text_chunk.strip():
-                        page_chunks.append(SplitPage(page_num=page.page_num, text=text_chunk))
-                builder = []
-                builder_token_len = 0
+            # Builder encapsulates accumulation logic
+            builder = _ChunkBuilder(
+                page_num=page.page_num,
+                max_chars=self.max_section_length,
+                max_tokens=self.max_tokens_per_section,
+            )

             for btype, btext in blocks:
                 if btype == "figure":
-                    if builder:
+                    if builder.has_content():
                         # Append figure to existing text (allow overflow) and flush
-                        builder.append(btext)
-                        flush_builder_to_page()
+                        builder.append_figure_and_flush(btext, page_chunks)
                     else:
                         # Emit figure standalone
                         if btext.strip():
@@ -274,21 +352,19 @@ def flush_builder_to_page():
                 unit_tokens = len(bpe.encode(unit))
                 # If a single unit itself exceeds token limit (rare, very long sentence), split it directly
                 if unit_tokens > self.max_tokens_per_section:
-                    flush_builder_to_page()
+                    builder.flush_into(page_chunks)
                     for sp in self.split_page_by_max_tokens(page.page_num, unit, allow_figure_processing=False):
                         page_chunks.append(sp)
                     continue
-                if builder and (
-                    len("".join(builder)) + len(unit) > self.max_section_length
-                    or builder_token_len + unit_tokens > self.max_tokens_per_section
-                ):
-                    # Flush current builder before starting new one with this unit
-                    flush_builder_to_page()
-                builder.append(unit)
-                builder_token_len += unit_tokens
+                if not builder.add(unit, unit_tokens):
+                    # Flush and retry (guaranteed to fit because unit_tokens <= limit)
+                    builder.flush_into(page_chunks)
+                    added = builder.add(unit, unit_tokens)
+                    if not added:  # defensive (should not happen)
+                        page_chunks.append(SplitPage(page_num=page.page_num, text=unit))

             # Flush any trailing builder content
-            flush_builder_to_page()
+            builder.flush_into(page_chunks)

             # Attempt cross-page merge with previous_chunk (look-behind) if semantic continuation
             if previous_chunk and page_chunks:
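
The add/flush/retry pattern above can be exercised in isolation. A minimal sketch with hand-picked token counts (the real code derives them from tiktoken via bpe.encode; the import path is hypothetical):

from prepdocslib.textsplitter import SplitPage, _ChunkBuilder  # hypothetical path

out: list[SplitPage] = []
builder = _ChunkBuilder(page_num=0, max_chars=30, max_tokens=10)
for unit, tokens in [("First sentence. ", 4), ("Second sentence. ", 4), ("Third one. ", 3)]:
    if not builder.add(unit, tokens):      # unit would overflow the current chunk
        builder.flush_into(out)            # close the chunk...
        builder.add(unit, tokens)          # ...and retry; it fits an empty builder
builder.flush_into(out)
print([sp.text for sp in out])  # ['First sentence. ', 'Second sentence. Third one. ']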
@@ -379,27 +455,14 @@ def fits(candidate: str) -> bool:

             # Normalize chunks (non-figure) that barely exceed char limit due to added boundary space
             max_chars = int(self.max_section_length * 1.2)
-
-            def _normalize(text: str) -> str:
-                lower = text.lower()
-                if "<figure" in lower:
-                    return text  # allow overflow for figures
-                if len(text) <= max_chars:
-                    return text
-                # Trim leading spaces first (most common cause after safe concat)
-                trimmed = text
-                while trimmed.startswith(" ") and len(trimmed) > max_chars:
-                    trimmed = trimmed[1:]
-                # As a fallback, if still barely over (<= max_chars+3), try removing a trailing space/newline
-                if len(trimmed) > max_chars and len(trimmed) <= max_chars + 3:
-                    if trimmed.endswith(" ") or trimmed.endswith("\n"):
-                        trimmed = trimmed.rstrip()
-                return trimmed
-
             if previous_chunk:
-                previous_chunk = SplitPage(page_num=previous_chunk.page_num, text=_normalize(previous_chunk.text))
+                previous_chunk = SplitPage(
+                    page_num=previous_chunk.page_num, text=_normalize_chunk(previous_chunk.text, max_chars)
+                )
             if page_chunks:
-                page_chunks = [SplitPage(page_num=sp.page_num, text=_normalize(sp.text)) for sp in page_chunks]
+                page_chunks = [
+                    SplitPage(page_num=sp.page_num, text=_normalize_chunk(sp.text, max_chars)) for sp in page_chunks
+                ]

             # Emit previous_chunk now that merge opportunity considered
             if previous_chunk:
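
The normalization tolerance can likewise be checked directly; another sketch under the same import assumption:

from prepdocslib.textsplitter import _normalize_chunk  # hypothetical path

padded = "  " + "x" * 1200                          # two leading spaces push it just over
assert len(_normalize_chunk(padded, 1200)) == 1200  # leading spaces trimmed first
fig = "<figure>big diagram</figure>" + "x" * 5000
assert _normalize_chunk(fig, 1200) == fig           # figure chunks are never trimmed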