Commit 7cebf13
Add some helper functions and modules to improve code clarity for textsplitter
1 parent: 90420c5

1 file changed: app/backend/prepdocslib/textsplitter.py (+124 -61 lines)
@@ -2,6 +2,7 @@
 import re
 from abc import ABC
 from collections.abc import Generator
+from dataclasses import dataclass, field
 from typing import Optional
 
 import tiktoken
@@ -84,6 +85,107 @@ def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
 DEFAULT_SECTION_LENGTH = 1000  # Roughly 400-500 tokens for English
 
 
+def _safe_concat(a: str, b: str) -> str:
+    """Concatenate two non-empty segments, inserting a space only when both sides
+    end/start with alphanumerics and no natural boundary exists.
+
+    Rules:
+    - Both input strings are expected to be non-empty
+    - Preserve existing whitespace if either side already provides a boundary.
+    - Do not insert a space after a closing HTML tag marker '>'.
+    - If both boundary characters are alphanumeric, insert a single space.
+    - Otherwise concatenate directly.
+    """
+    assert a and b, "_safe_concat expects non-empty strings"
+    a_last = a[-1]
+    b_first = b[0]
+    if a_last.isspace() or b_first.isspace():  # pre-existing boundary
+        return a + b
+    if a_last == ">":  # HTML tag end acts as a boundary
+        return a + b
+    if a_last.isalnum() and b_first.isalnum():  # need explicit separator
+        return a + " " + b
+    return a + b
+
+
+def _normalize_chunk(text: str, max_chars: int) -> str:
+    """Normalize a non-figure chunk that may slightly exceed max_chars.
+
+    Allows overflow for any chunk containing a <figure ...> tag (figures are atomic),
+    trims leading spaces if they alone cause minor overflow, and as a final step
+    removes a trailing space/newline when within a small tolerance (<=3 chars over).
+    """
+    lower = text.lower()
+    if "<figure" in lower:
+        return text  # never trim figure chunks
+    if len(text) <= max_chars:
+        return text
+    trimmed = text
+    while trimmed.startswith(" ") and len(trimmed) > max_chars:
+        trimmed = trimmed[1:]
+    if len(trimmed) > max_chars and len(trimmed) <= max_chars + 3:
+        if trimmed.endswith(" ") or trimmed.endswith("\n"):
+            trimmed = trimmed.rstrip()
+    return trimmed
+
+
+@dataclass
+class _ChunkBuilder:
+    """Accumulates sentence-like units for a single page until size limits are reached.
+
+    Responsibilities:
+    - Track appended text fragments and their approximate token length.
+    - Decide if a new unit can be added without exceeding character or token thresholds.
+    - Flush accumulated content into an output list as a `SplitPage`.
+    - Allow a figure block to be force-appended (even if it overflows) so that headings + figure stay together.
+
+    Notes:
+    - Character limit is soft (exact enforcement + later normalization); token limit is hard.
+    - Token counts are computed by the caller and passed to `add`; this class stays agnostic of the encoder.
+    """
+
+    page_num: int
+    max_chars: int
+    max_tokens: int
+    parts: list[str] = field(default_factory=list)
+    token_len: int = 0
+
+    def can_fit(self, text: str, token_count: int) -> bool:
+        if not self.parts:  # always allow first unit
+            return token_count <= self.max_tokens and len(text) <= self.max_chars
+        # Character + token constraints
+        return (len("".join(self.parts)) + len(text) <= self.max_chars) and (
+            self.token_len + token_count <= self.max_tokens
+        )
+
+    def add(self, text: str, token_count: int) -> bool:
+        if not self.can_fit(text, token_count):
+            return False
+        self.parts.append(text)
+        self.token_len += token_count
+        return True
+
+    def force_append(self, text: str):
+        self.parts.append(text)
+
+    def flush_into(self, out: list[SplitPage]):
+        if self.parts:
+            chunk = "".join(self.parts)
+            if chunk.strip():
+                out.append(SplitPage(page_num=self.page_num, text=chunk))
+            self.parts.clear()
+            self.token_len = 0
+
+    # Convenience helpers for readability at call sites
+    def has_content(self) -> bool:
+        return bool(self.parts)
+
+    def append_figure_and_flush(self, figure_text: str, out: list[SplitPage]):
+        """Append a figure (allowed to overflow) to current accumulation and flush in one step."""
+        self.force_append(figure_text)
+        self.flush_into(out)
+
+
 class SentenceTextSplitter(TextSplitter):
     """
     Class that splits pages into smaller chunks. This is required because embedding models may not be able to analyze an entire page at once
@@ -203,23 +305,6 @@ def split_pages(self, pages: list[Page]) -> Generator[SplitPage, None, None]:
             if not raw.strip():
                 continue
 
-            def _safe_concat(a: str, b: str) -> str:
-                """Concatenate two non-empty segments, inserting a space only when both sides
-                end/start with alphanumerics and no natural boundary exists. (Empty inputs are
-                never passed here by construction.)"""
-                a_last = a[-1]
-                b_first = b[0]
-                # If either already has whitespace boundary, just concat
-                if a_last.isspace() or b_first.isspace():
-                    return a + b
-                # If a ends with '>' (HTML tag) we assume boundary is intentional.
-                if a_last == ">":
-                    return a + b
-                # If both alnum, insert space.
-                if a_last.isalnum() and b_first.isalnum():
-                    return a + " " + b
-                return a + b
-
             # 1. Build ordered list of blocks: (type, text)
             blocks: list[tuple[str, str]] = []
             last = 0
@@ -234,25 +319,18 @@ def _safe_concat(a: str, b: str) -> str:
             # Accumulated chunks for this page
             page_chunks: list[SplitPage] = []
 
-            # Builder state for text accumulation
-            builder: list[str] = []
-            builder_token_len = 0
-
-            def flush_builder_to_page():
-                nonlocal builder, builder_token_len
-                if builder:
-                    text_chunk = "".join(builder)
-                    if text_chunk.strip():
-                        page_chunks.append(SplitPage(page_num=page.page_num, text=text_chunk))
-                builder = []
-                builder_token_len = 0
+            # Builder encapsulates accumulation logic
+            builder = _ChunkBuilder(
+                page_num=page.page_num,
+                max_chars=self.max_section_length,
+                max_tokens=self.max_tokens_per_section,
+            )
 
             for btype, btext in blocks:
                 if btype == "figure":
-                    if builder:
+                    if builder.has_content():
                         # Append figure to existing text (allow overflow) and flush
-                        builder.append(btext)
-                        flush_builder_to_page()
+                        builder.append_figure_and_flush(btext, page_chunks)
                     else:
                         # Emit figure standalone
                         if btext.strip():
@@ -274,21 +352,19 @@ def flush_builder_to_page():
                 unit_tokens = len(bpe.encode(unit))
                 # If a single unit itself exceeds token limit (rare, very long sentence), split it directly
                 if unit_tokens > self.max_tokens_per_section:
-                    flush_builder_to_page()
+                    builder.flush_into(page_chunks)
                     for sp in self.split_page_by_max_tokens(page.page_num, unit, allow_figure_processing=False):
                         page_chunks.append(sp)
                     continue
-                if builder and (
-                    len("".join(builder)) + len(unit) > self.max_section_length
-                    or builder_token_len + unit_tokens > self.max_tokens_per_section
-                ):
-                    # Flush current builder before starting new one with this unit
-                    flush_builder_to_page()
-                builder.append(unit)
-                builder_token_len += unit_tokens
+                if not builder.add(unit, unit_tokens):
+                    # Flush and retry (guaranteed to fit because unit_tokens <= limit)
+                    builder.flush_into(page_chunks)
+                    added = builder.add(unit, unit_tokens)
+                    if not added:  # defensive (should not happen)
+                        page_chunks.append(SplitPage(page_num=page.page_num, text=unit))
 
             # Flush any trailing builder content
-            flush_builder_to_page()
+            builder.flush_into(page_chunks)
 
             # Attempt cross-page merge with previous_chunk (look-behind) if semantic continuation
             if previous_chunk and page_chunks:
@@ -379,27 +455,14 @@ def fits(candidate: str) -> bool:
 
             # Normalize chunks (non-figure) that barely exceed char limit due to added boundary space
             max_chars = int(self.max_section_length * 1.2)
-
-            def _normalize(text: str) -> str:
-                lower = text.lower()
-                if "<figure" in lower:
-                    return text  # allow overflow for figures
-                if len(text) <= max_chars:
-                    return text
-                # Trim leading spaces first (most common cause after safe concat)
-                trimmed = text
-                while trimmed.startswith(" ") and len(trimmed) > max_chars:
-                    trimmed = trimmed[1:]
-                # As a fallback, if still barely over (<= max_chars+3), try removing a trailing space/newline
-                if len(trimmed) > max_chars and len(trimmed) <= max_chars + 3:
-                    if trimmed.endswith(" ") or trimmed.endswith("\n"):
-                        trimmed = trimmed.rstrip()
-                return trimmed
-
             if previous_chunk:
-                previous_chunk = SplitPage(page_num=previous_chunk.page_num, text=_normalize(previous_chunk.text))
+                previous_chunk = SplitPage(
+                    page_num=previous_chunk.page_num, text=_normalize_chunk(previous_chunk.text, max_chars)
+                )
             if page_chunks:
-                page_chunks = [SplitPage(page_num=sp.page_num, text=_normalize(sp.text)) for sp in page_chunks]
+                page_chunks = [
+                    SplitPage(page_num=sp.page_num, text=_normalize_chunk(sp.text, max_chars)) for sp in page_chunks
+                ]
 
             # Emit previous_chunk now that merge opportunity considered
             if previous_chunk:
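
Below is a small usage sketch of the new helpers (not part of the commit). It assumes SplitPage, _safe_concat, _normalize_chunk, and _ChunkBuilder are all importable from prepdocslib.textsplitter, as the diff's type hints suggest, and it uses tiny illustrative limits and hand-picked token counts rather than the module defaults and tiktoken:

# Usage sketch only; names come from the diff, the import path is an assumption.
from prepdocslib.textsplitter import SplitPage, _ChunkBuilder, _normalize_chunk, _safe_concat

# _safe_concat inserts a space only when both boundary characters are alphanumeric.
assert _safe_concat("Hello", "world") == "Hello world"
assert _safe_concat("</p>", "Next") == "</p>Next"    # '>' already acts as a boundary
assert _safe_concat("end. ", "Next") == "end. Next"  # existing whitespace is preserved
assert _safe_concat("end.", "Next") == "end.Next"    # punctuation already separates

# _normalize_chunk trims minor overflow but never touches figure chunks.
assert _normalize_chunk("  abcd", max_chars=4) == "abcd"  # leading spaces trimmed
assert _normalize_chunk("abc \n", max_chars=3) == "abc"   # trailing whitespace, within the 3-char tolerance
assert _normalize_chunk("<figure>x</figure>", max_chars=5) == "<figure>x</figure>"

# _ChunkBuilder accepts units until a char or token limit would be exceeded; token
# counts are supplied by the caller (tiktoken in split_pages, plain ints here).
out: list[SplitPage] = []
builder = _ChunkBuilder(page_num=0, max_chars=20, max_tokens=10)
assert builder.add("Hello world. ", 3)         # fits: 13 chars, 3 tokens
assert not builder.add("More text here.", 4)   # 13 + 15 chars > 20, rejected
builder.flush_into(out)                        # emits the accumulated chunk
assert builder.add("More text here.", 4)       # retry succeeds on the emptied builder
builder.flush_into(out)
assert [sp.text for sp in out] == ["Hello world. ", "More text here."]

The flush-and-retry in the last lines mirrors the call-site pattern in split_pages: a rejected unit triggers a flush, after which the same unit is guaranteed to fit, because oversized units were already split off via split_page_by_max_tokens.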
