Skip to content

Commit 2d3b4fd

Browse files
Add unit tests
1 parent d71aad0 commit 2d3b4fd

File tree

1 file changed

+137
-1
lines changed

1 file changed

+137
-1
lines changed

tests/unit/experimental/components/text_splitters/test_fixed_size_splitter.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
19-
FixedSizeSplitter,
19+
FixedSizeSplitter, _adjust_chunk_start, _adjust_chunk_end,
2020
)
2121
from neo4j_graphrag.experimental.components.types import TextChunk
2222

@@ -78,3 +78,139 @@ def test_invalid_chunk_overlap() -> None:
7878
with pytest.raises(ValueError) as excinfo:
7979
FixedSizeSplitter(5, 5)
8080
assert "chunk_overlap must be strictly less than chunk_size" in str(excinfo)
81+
82+
83+
def test_invalid_chunk_size() -> None:
    """Check that ``FixedSizeSplitter(0, 0)`` raises a chunk_size ValueError."""
    with pytest.raises(ValueError) as excinfo:
        FixedSizeSplitter(0, 0)
    # Inspect the raised exception itself: ``str(excinfo)`` is the repr of
    # pytest's ExceptionInfo wrapper, not the error message.
    assert "chunk_size must be strictly greater than 0" in str(excinfo.value)
87+
88+
89+
@pytest.mark.parametrize(
    "text, approximate_start, expected_start",
    [
        # Case: approximate_start is at word boundary already
        ("Hello World", 6, 6),
        # Case: approximate_start is at a whitespace already
        ("Hello World", 5, 5),
        # Case: approximate_start is at the middle of word and no whitespace is found
        ("Hello World", 2, 2),
        # Case: approximate_start is at the middle of a word
        ("Hello World", 8, 6),
        # Case: approximate_start = 0
        ("Hello World", 0, 0),
    ],
)
def test_adjust_chunk_start(
    text: str, approximate_start: int, expected_start: int
) -> None:
    """
    Test that ``_adjust_chunk_start`` returns the expected start index.

    Each case supplies a text, an approximate start index and the index the
    helper is expected to return: the start is moved so that a word is not
    split, and is left unchanged when it already sits on a boundary or when
    no whitespace is found.
    """
    result = _adjust_chunk_start(text, approximate_start)
    assert result == expected_start
111+
112+
113+
@pytest.mark.parametrize(
    "text, start, approximate_end, expected_end",
    [
        # Case: approximate_end is at word boundary already
        ("Hello World", 0, 5, 5),
        # Case: approximate_end is at the middle of a word
        ("Hello World", 0, 8, 6),
        # Case: approximate_end is at the middle of word and no whitespace is found
        ("Hello World", 0, 3, 3),
        # Case: adjusted_end == start => fallback to approximate_end
        ("Hello World", 6, 7, 7),
        # Case: end>=len(text)
        ("Hello World", 6, 15, 15),
    ],
)
def test_adjust_chunk_end(
    text: str, start: int, approximate_end: int, expected_end: int
) -> None:
    """
    Test that ``_adjust_chunk_end`` returns the expected end index.

    Each case supplies a text, the chunk start, an approximate end index and
    the index the helper is expected to return: the end is moved so that a
    word is not split, and is left unchanged when it already sits on a
    boundary, when no whitespace is found, or when the adjustment would
    collapse the chunk onto its start.
    """
    result = _adjust_chunk_end(text, start, approximate_end)
    assert result == expected_end
135+
136+
137+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "text, chunk_size, chunk_overlap, approximate, expected_chunks",
    [
        # Case: approximate fixed size splitting
        (
            "Hello World, this is a test message.",
            10,
            2,
            True,
            [
                "Hello ",
                "World, ",
                "this is a ",
                "a test ",
                "message.",
            ],
        ),
        # Case: fixed size splitting
        (
            "Hello World, this is a test message.",
            10,
            2,
            False,
            [
                "Hello Worl",
                "rld, this ",
                "s is a tes",
                "est messag",
                "age.",
            ],
        ),
        # Case: short text, far below chunk_size => only one chunk
        (
            "Short text",
            20,
            5,
            True,
            ["Short text"],
        ),
        # Case: short text, just below chunk_size => only one chunk
        (
            "Short text",
            12,
            4,
            True,
            ["Short text"],
        ),
        # Case: text with no spaces
        (
            "1234567890",
            5,
            1,
            True,
            ["12345", "56789", "90"],
        ),
    ],
)
async def test_fixed_size_splitter_run(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    approximate: bool,
    expected_chunks: list[str],
) -> None:
    """
    Test that 'FixedSizeSplitter.run' returns the expected chunks
    for different configurations.

    For every case the number of chunks, each chunk's text, its type and
    its index are compared against the expected values.
    """
    splitter = FixedSizeSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        approximate=approximate,
    )
    text_chunks = await splitter.run(text)

    # Verify number of chunks
    assert len(text_chunks.chunks) == len(expected_chunks)

    # Verify content of each chunk
    for i, expected_text in enumerate(expected_chunks):
        chunk = text_chunks.chunks[i]
        assert chunk.text == expected_text
        assert isinstance(chunk, TextChunk)
        assert chunk.index == i

0 commit comments

Comments
 (0)