Skip to content

Commit 45b370f

Browse files
authored
Merge pull request #8 from gmanvel/main
Performance Optimizations for TokenTextChunker
2 parents 56b6389 + 22e277d commit 45b370f

File tree

3 files changed

+176
-15
lines changed

3 files changed

+176
-15
lines changed

benchmarks/ManagedCode.GraphRag.Benchmarks/Chunking/TokenTextChunkerBenchmarks.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
namespace ManagedCode.GraphRag.Benchmarks.Chunking;
66

77
[MemoryDiagnoser]
8+
[HideColumns("Error", "StdDev", "RatioSD")]
89
public class TokenTextChunkerBenchmarks
910
{
1011
private TokenTextChunker _chunker = null!;
@@ -36,7 +37,7 @@ public void Setup()
3637
_largeDocument = new[] { new ChunkSlice("doc1", GeneratePlainTextDocument(1_000_000)) };
3738
}
3839

39-
[Benchmark]
40+
[Benchmark(Baseline = true)]
4041
public IReadOnlyList<TextChunk> ChunkSmallDocument()
4142
{
4243
return _chunker.Chunk(_smallDocument, _config);

src/ManagedCode.GraphRag/Chunking/TokenTextChunker.cs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using System.Buffers;
2+
using System.Runtime.InteropServices;
13
using GraphRag.Config;
24
using GraphRag.Tokenization;
35

@@ -12,48 +14,62 @@ public IReadOnlyList<TextChunk> Chunk(IReadOnlyList<ChunkSlice> slices, Chunking
1214

1315
if (slices.Count == 0)
1416
{
15-
return Array.Empty<TextChunk>();
17+
return [];
1618
}
1719

1820
var tokenizer = TokenizerRegistry.GetTokenizer(config.EncodingModel);
1921
var flattened = new List<(int SliceIndex, int Token)>();
22+
2023
for (var index = 0; index < slices.Count; index++)
2124
{
2225
var slice = slices[index];
23-
var encoded = tokenizer.EncodeToIds(slice.Text);
24-
foreach (var token in encoded)
26+
var encoded = tokenizer.EncodeToIds(slice.Text.AsSpan());
27+
for (var i = 0; i < encoded.Count; i++)
2528
{
29+
var token = encoded[i];
2630
flattened.Add((index, token));
2731
}
2832
}
2933

3034
if (flattened.Count == 0)
3135
{
32-
return Array.Empty<TextChunk>();
36+
return [];
3337
}
3438

3539
var chunkSize = Math.Max(1, config.Size);
3640
var overlap = Math.Clamp(config.Overlap, 0, chunkSize - 1);
37-
var results = new List<TextChunk>();
41+
42+
var step = chunkSize - overlap;
43+
var estimatedChunks = (flattened.Count + step - 1) / step;
44+
var results = new List<TextChunk>(estimatedChunks);
45+
46+
var documentIds = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
3847

3948
var start = 0;
4049
while (start < flattened.Count)
4150
{
4251
var end = Math.Min(flattened.Count, start + chunkSize);
43-
var chunkTokens = flattened.GetRange(start, end - start);
44-
var tokenValues = new int[chunkTokens.Count];
45-
for (var i = 0; i < chunkTokens.Count; i++)
52+
var chunkTokens = CollectionsMarshal.AsSpan(flattened).Slice(start, end - start);
53+
var tokenValues = ArrayPool<int>.Shared.Rent(chunkTokens.Length);
54+
documentIds.Clear();
55+
56+
var lastSliceIndex = -1;
57+
for (var i = 0; i < chunkTokens.Length; i++)
4658
{
59+
var sliceIndex = chunkTokens[i].SliceIndex;
4760
tokenValues[i] = chunkTokens[i].Token;
61+
62+
if (sliceIndex != lastSliceIndex)
63+
{
64+
documentIds.Add(slices[sliceIndex].DocumentId);
65+
lastSliceIndex = sliceIndex;
66+
}
4867
}
4968

50-
var decoded = tokenizer.Decode(tokenValues);
51-
var documentIds = chunkTokens
52-
.Select(tuple => slices[tuple.SliceIndex].DocumentId)
53-
.Distinct(StringComparer.OrdinalIgnoreCase)
54-
.ToArray();
69+
var decoded = tokenizer.Decode(new ArraySegment<int>(tokenValues, 0, chunkTokens.Length));
70+
results.Add(new TextChunk(documentIds.ToList(), decoded, chunkTokens.Length));
5571

56-
results.Add(new TextChunk(documentIds, decoded, tokenValues.Length));
72+
ArrayPool<int>.Shared.Return(tokenValues);
5773

5874
if (end >= flattened.Count)
5975
{

tests/ManagedCode.GraphRag.Tests/Chunking/TokenTextChunkerTests.cs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ namespace ManagedCode.GraphRag.Tests.Chunking;
88
public sealed class TokenTextChunkerTests
99
{
1010
private readonly TokenTextChunker _chunker = new();
11+
private readonly ChunkingConfig _defaultConfig = new()
12+
{
13+
Size = 40,
14+
Overlap = 10,
15+
EncodingModel = TokenizerDefaults.DefaultEncoding
16+
};
1117

1218
[Fact]
1319
public void Chunk_RespectsTokenBudget()
@@ -63,4 +69,142 @@ public void Chunk_CombinesDocumentIdentifiersAcrossSlices()
6369
Assert.Contains(chunks, chunk => chunk.DocumentIds.Contains("doc-1"));
6470
Assert.Contains(chunks, chunk => chunk.DocumentIds.Contains("doc-2"));
6571
}
72+
73+
[Fact]
74+
public void Chunk_OverlapProducesSharedTokensBetweenAdjacentChunks()
75+
{
76+
var tokenizer = TokenizerRegistry.GetTokenizer(TokenizerDefaults.DefaultEncoding);
77+
const string text = "The quick brown fox jumps over the lazy dog and continues running through the forest until it reaches the river where it stops to drink some water.";
78+
var slices = new[] { new ChunkSlice("doc-1", text) };
79+
80+
var config = new ChunkingConfig
81+
{
82+
Size = 20,
83+
Overlap = 5,
84+
EncodingModel = TokenizerDefaults.DefaultEncoding
85+
};
86+
87+
var chunks = _chunker.Chunk(slices, config);
88+
89+
Assert.True(chunks.Count >= 2, "Need at least 2 chunks to verify overlap");
90+
91+
for (var i = 0; i < chunks.Count - 1; i++)
92+
{
93+
var currentChunkTokens = tokenizer.EncodeToIds(chunks[i].Text);
94+
var nextChunkTokens = tokenizer.EncodeToIds(chunks[i + 1].Text);
95+
96+
var lastTokensOfCurrent = currentChunkTokens.TakeLast(config.Overlap).ToArray();
97+
var firstTokensOfNext = nextChunkTokens.Take(config.Overlap).ToArray();
98+
99+
Assert.Equal(lastTokensOfCurrent, firstTokensOfNext);
100+
}
101+
}
102+
103+
[Fact]
104+
public void Chunk_EmptySlicesReturnsEmptyResult()
105+
{
106+
var slices = Array.Empty<ChunkSlice>();
107+
108+
var chunks = _chunker.Chunk(slices, _defaultConfig);
109+
110+
Assert.Empty(chunks);
111+
}
112+
113+
[Fact]
114+
public void Chunk_SlicesWithEmptyTextReturnsEmptyResult()
115+
{
116+
var slices = new[] { new ChunkSlice("doc-1", string.Empty) };
117+
118+
var chunks = _chunker.Chunk(slices, _defaultConfig);
119+
120+
Assert.Empty(chunks);
121+
}
122+
123+
[Fact]
124+
public void Chunk_NullSlicesThrowsArgumentNullException()
125+
{
126+
Assert.Throws<ArgumentNullException>(() => _chunker.Chunk(null!, _defaultConfig));
127+
}
128+
129+
[Fact]
130+
public void Chunk_NullConfigThrowsArgumentNullException()
131+
{
132+
var slices = new[] { new ChunkSlice("doc-1", "Some text") };
133+
134+
Assert.Throws<ArgumentNullException>(() => _chunker.Chunk(slices, null!));
135+
}
136+
137+
[Fact]
138+
public void Chunk_ZeroOverlapProducesNonOverlappingChunks()
139+
{
140+
var tokenizer = TokenizerRegistry.GetTokenizer(TokenizerDefaults.DefaultEncoding);
141+
const string text = "The quick brown fox jumps over the lazy dog and continues running through the forest until it reaches the river.";
142+
var slices = new[] { new ChunkSlice("doc-1", text) };
143+
144+
var config = new ChunkingConfig
145+
{
146+
Size = 15,
147+
Overlap = 0,
148+
EncodingModel = TokenizerDefaults.DefaultEncoding
149+
};
150+
151+
var chunks = _chunker.Chunk(slices, config);
152+
Assert.True(chunks.Count >= 2, "Need at least 2 chunks to verify zero overlap");
153+
154+
var allChunkTokens = chunks
155+
.SelectMany(c => tokenizer.EncodeToIds(c.Text))
156+
.ToList();
157+
158+
var originalTokens = tokenizer.EncodeToIds(text);
159+
160+
Assert.Equal(originalTokens.Count, allChunkTokens.Count);
161+
}
162+
163+
[Fact]
164+
public void Chunk_InputSmallerThanChunkSizeReturnsSingleChunk()
165+
{
166+
const string shortText = "Hello world";
167+
var slices = new[] { new ChunkSlice("doc-1", shortText) };
168+
169+
var config = new ChunkingConfig
170+
{
171+
Size = 100,
172+
Overlap = 10,
173+
EncodingModel = TokenizerDefaults.DefaultEncoding
174+
};
175+
176+
var chunks = _chunker.Chunk(slices, config);
177+
178+
Assert.Single(chunks);
179+
Assert.Equal(shortText, chunks[0].Text);
180+
}
181+
182+
[Fact]
183+
public void Chunk_ExactBoundaryProducesExpectedChunkCount()
184+
{
185+
var tokenizer = TokenizerRegistry.GetTokenizer(TokenizerDefaults.DefaultEncoding);
186+
187+
const int chunkSize = 10;
188+
const int overlap = 2;
189+
const int step = chunkSize - overlap;
190+
191+
var targetTokenCount = step * 3 + overlap;
192+
var words = Enumerable.Range(0, targetTokenCount * 2).Select(i => "word").ToArray();
193+
var text = string.Join(" ", words);
194+
195+
var actualTokens = tokenizer.EncodeToIds(text);
196+
var slices = new[] { new ChunkSlice("doc-1", text) };
197+
198+
var config = new ChunkingConfig
199+
{
200+
Size = chunkSize,
201+
Overlap = overlap,
202+
EncodingModel = TokenizerDefaults.DefaultEncoding
203+
};
204+
205+
var chunks = _chunker.Chunk(slices, config);
206+
207+
Assert.True(chunks.Count >= 2, "Should produce multiple chunks");
208+
Assert.All(chunks.SkipLast(1), chunk => Assert.Equal(chunkSize, chunk.TokenCount));
209+
}
66210
}

0 commit comments

Comments
 (0)