
Commit 4a8f5cd

Add alternative token-based text splitter (#816)
This splitter does not use a separator; it naively chunks the input text at fixed boundaries in token space. This is helpful when we have strict token length limits and must follow the specified chunk size exactly, but cannot rely on aggressive separators such as spaces to rule out overly long strings. CharacterTextSplitter lets such strings through without splitting them, which can cause overflow errors downstream. Splitting at arbitrary token boundaries is not ideal, but the cost is hopefully mitigated by a decent overlap quantity. This approach also yields chunks with exactly the desired number of tokens, instead of sometimes overcounting when shorter strings are concatenated. Potentially also helps with #528.
1 parent 523ad2e commit 4a8f5cd
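
To make the mechanics concrete, here is a minimal sketch of the idea (not the committed code): encode the text to token ids with tiktoken, slide a fixed-size window over the ids with a stride of chunk_size - chunk_overlap, and decode each window back to text. The sample string and parameter values are placeholders.

import tiktoken

# Encode once; chunk in token space rather than character space.
enc = tiktoken.get_encoding("gpt2")
input_ids = enc.encode("some long document text ...")

chunk_size, chunk_overlap = 5, 1
stride = chunk_size - chunk_overlap

chunks = []
for start in range(0, len(input_ids), stride):
    window = input_ids[start : start + chunk_size]  # never more than chunk_size tokens
    chunks.append(enc.decode(window))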

File tree

2 files changed: +51 -1 lines changed

langchain/text_splitter.py

Lines changed: 32 additions & 0 deletions
@@ -146,6 +146,38 @@ def split_text(self, text: str) -> List[str]:
         return self._merge_splits(splits, self._separator)
 
 
+class TokenTextSplitter(TextSplitter):
+    """Implementation of splitting text that looks at tokens."""
+
+    def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs)
+        try:
+            import tiktoken
+        except ImportError:
+            raise ValueError(
+                "Could not import tiktoken python package. "
+                "This is needed for TokenTextSplitter. "
+                "Please install it with `pip install tiktoken`."
+            )
+        # Create a GPT-2 (BPE) encoder instance.
+        self._tokenizer = tiktoken.get_encoding(encoding_name)
+
+    def split_text(self, text: str) -> List[str]:
+        """Split incoming text and return chunks."""
+        splits = []
+        input_ids = self._tokenizer.encode(text)
+        start_idx = 0
+        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+        while start_idx < len(input_ids):
+            splits.append(self._tokenizer.decode(chunk_ids))
+            start_idx += self._chunk_size - self._chunk_overlap
+            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
+            chunk_ids = input_ids[start_idx:cur_idx]
+        return splits
+
+
 class RecursiveCharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters.
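
Usage, as a sketch: the tests below pass chunk_size and chunk_overlap as keyword arguments, so they are presumably consumed by the TextSplitter base class; encoding_name defaults to "gpt2".

from langchain.text_splitter import TokenTextSplitter

# chunk_size/chunk_overlap are assumed to be handled by the TextSplitter base class.
splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=5, chunk_overlap=1)
chunks = splitter.split_text("abcdef" * 5)
# Each chunk decodes from at most chunk_size token ids, so a strict
# token limit holds exactly for every chunk.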

tests/integration_tests/test_text_splitter.py

Lines changed: 19 additions & 1 deletion
@@ -2,7 +2,7 @@
 
 import pytest
 
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
 
 
 def test_huggingface_type_check() -> None:
@@ -21,3 +21,21 @@ def test_huggingface_tokenizer() -> None:
     )
     output = text_splitter.split_text("foo bar")
     assert output == ["foo", "bar"]
+
+
+class TestTokenTextSplitter:
+    """Test token text splitter."""
+
+    def test_basic(self) -> None:
+        """Test no overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+        assert output == expected_output
+
+    def test_overlap(self) -> None:
+        """Test with overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+        assert output == expected_output
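
To see why test_overlap expects exactly three chunks, here is the window arithmetic the loop performs for chunk_size=5, chunk_overlap=1, and the 10 tokens noted in the test comment (a standalone sketch, not library code):

chunk_size, chunk_overlap, n_tokens = 5, 1, 10
stride = chunk_size - chunk_overlap  # start indices advance by 4

starts = list(range(0, n_tokens, stride))                      # [0, 4, 8]
windows = [(s, min(s + chunk_size, n_tokens)) for s in starts]
print(windows)  # [(0, 5), (4, 9), (8, 10)] -> three chunks, as the test expects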

0 commit comments