|
16 | 16 |
|
17 | 17 | import pytest |
18 | 18 | from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( |
19 | | - FixedSizeSplitter, |
| 19 | + FixedSizeSplitter, _adjust_chunk_start, _adjust_chunk_end, |
20 | 20 | ) |
21 | 21 | from neo4j_graphrag.experimental.components.types import TextChunk |
22 | 22 |
|
@@ -78,3 +78,139 @@ def test_invalid_chunk_overlap() -> None: |
78 | 78 | with pytest.raises(ValueError) as excinfo: |
79 | 79 | FixedSizeSplitter(5, 5) |
80 | 80 | assert "chunk_overlap must be strictly less than chunk_size" in str(excinfo) |
| 81 | + |
| 82 | + |
def test_invalid_chunk_size() -> None:
    """Constructing ``FixedSizeSplitter`` with a chunk size of 0 raises ``ValueError``."""
    with pytest.raises(ValueError) as excinfo:
        FixedSizeSplitter(0, 0)
    # Match against the raised exception itself (``excinfo.value``).
    # ``str(excinfo)`` stringifies pytest's ExceptionInfo wrapper, whose
    # textual format is not what this test is meant to pin down.
    assert "chunk_size must be strictly greater than 0" in str(excinfo.value)
| 87 | + |
| 88 | + |
@pytest.mark.parametrize(
    "text, approximate_start, expected_start",
    [
        # approximate_start is at a word boundary already
        pytest.param("Hello World", 6, 6, id="at-word-boundary"),
        # approximate_start is at a whitespace already
        pytest.param("Hello World", 5, 5, id="at-whitespace"),
        # approximate_start is in the middle of a word and no whitespace is found
        pytest.param("Hello World", 2, 2, id="mid-word-no-whitespace"),
        # approximate_start is in the middle of a word
        pytest.param("Hello World", 8, 6, id="mid-word"),
        # approximate_start = 0
        pytest.param("Hello World", 0, 0, id="start-of-text"),
    ],
)
def test_adjust_chunk_start(
    text: str, approximate_start: int, expected_start: int
) -> None:
    """Check the start index returned by ``_adjust_chunk_start``.

    Each case passes ``text`` and ``approximate_start`` to the helper and
    compares the returned index with ``expected_start``. The case ids name
    the scenario so a failure report is self-describing.
    """
    assert _adjust_chunk_start(text, approximate_start) == expected_start
| 111 | + |
| 112 | + |
@pytest.mark.parametrize(
    "text, start, approximate_end, expected_end",
    [
        # approximate_end is at a word boundary already
        pytest.param("Hello World", 0, 5, 5, id="at-word-boundary"),
        # approximate_end is in the middle of a word
        pytest.param("Hello World", 0, 8, 6, id="mid-word"),
        # approximate_end is in the middle of a word and no whitespace is found
        pytest.param("Hello World", 0, 3, 3, id="mid-word-no-whitespace"),
        # adjusted_end == start => fallback to approximate_end
        pytest.param("Hello World", 6, 7, 7, id="fallback-to-approximate-end"),
        # end >= len(text)
        pytest.param("Hello World", 6, 15, 15, id="end-beyond-text"),
    ],
)
def test_adjust_chunk_end(
    text: str, start: int, approximate_end: int, expected_end: int
) -> None:
    """Check the end index returned by ``_adjust_chunk_end``.

    Each case passes ``text``, ``start`` and ``approximate_end`` to the helper
    and compares the returned index with ``expected_end``. The case ids name
    the scenario so a failure report is self-describing.
    """
    assert _adjust_chunk_end(text, start, approximate_end) == expected_end
| 135 | + |
| 136 | + |
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "text, chunk_size, chunk_overlap, approximate, expected_chunks",
    [
        pytest.param(
            "Hello World, this is a test message.",
            10,
            2,
            True,
            [
                "Hello ",
                "World, ",
                "this is a ",
                "a test ",
                "message.",
            ],
            id="approximate-splitting",
        ),
        pytest.param(
            "Hello World, this is a test message.",
            10,
            2,
            False,
            [
                "Hello Worl",
                "rld, this ",
                "s is a tes",
                "est messag",
                "age.",
            ],
            id="exact-splitting",
        ),
        # Text (10 chars) is much shorter than chunk_size => only one chunk
        pytest.param(
            "Short text",
            20,
            5,
            True,
            ["Short text"],
            id="short-text-single-chunk",
        ),
        # Text (10 chars) is just under chunk_size => still only one chunk
        pytest.param(
            "Short text",
            12,
            4,
            True,
            ["Short text"],
            id="short-text-near-chunk-size",
        ),
        # Text with no spaces
        pytest.param(
            "1234567890",
            5,
            1,
            True,
            ["12345", "56789", "90"],
            id="no-whitespace",
        ),
    ],
)
async def test_fixed_size_splitter_run(
    text: str,
    chunk_size: int,
    chunk_overlap: int,
    approximate: bool,
    expected_chunks: list[str],
) -> None:
    """Check the chunks produced by ``FixedSizeSplitter.run``.

    A splitter is built from ``chunk_size``, ``chunk_overlap`` and
    ``approximate``, run on ``text``, and its output is compared with
    ``expected_chunks``: same number of chunks, and for every position the
    chunk text, its type (``TextChunk``) and its ``index`` attribute.
    """
    splitter = FixedSizeSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        approximate=approximate,
    )
    text_chunks = await splitter.run(text)
    chunks = text_chunks.chunks

    # Compare lengths first: zip() below would silently stop at the shorter
    # sequence and hide missing or extra chunks.
    assert len(chunks) == len(expected_chunks)

    for position, (chunk, expected_text) in enumerate(zip(chunks, expected_chunks)):
        assert chunk.text == expected_text
        assert isinstance(chunk, TextChunk)
        assert chunk.index == position
0 commit comments