Skip to content

Commit b756361

Browse files
committed
Add custom splitter example
1 parent 4342c71 commit b756361

File tree

10 files changed

+132
-25
lines changed

10 files changed

+132
-25
lines changed

samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/rag/DocumentSplittingExample.java

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
import com.microsoft.semantic.kernel.rag.splitting.Chunk;
55
import com.microsoft.semantic.kernel.rag.splitting.Document;
66
import com.microsoft.semantic.kernel.rag.splitting.Splitter;
7+
import com.microsoft.semantic.kernel.rag.splitting.TextSplitter;
8+
import com.microsoft.semantic.kernel.rag.splitting.document.TextDocument;
9+
import com.microsoft.semantic.kernel.rag.splitting.overlap.NoOverlapCondition;
10+
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.CountSplitCondition;
11+
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoint;
12+
import com.microsoft.semantickernel.implementation.EmbeddedResourceLoader;
713
import java.io.ByteArrayInputStream;
814
import java.io.IOException;
915
import java.net.URI;
@@ -12,11 +18,14 @@
1218
import java.net.http.HttpResponse;
1319
import java.net.http.HttpResponse.BodyHandlers;
1420
import java.util.List;
21+
import java.util.regex.Pattern;
22+
import java.util.stream.Collectors;
1523
import org.apache.pdfbox.io.RandomAccessReadBuffer;
1624
import org.apache.pdfbox.pdfparser.PDFParser;
1725
import org.apache.pdfbox.pdmodel.PDDocument;
1826
import org.apache.pdfbox.text.PDFTextStripper;
1927
import reactor.core.publisher.Flux;
28+
import reactor.core.publisher.Mono;
2029

2130
public class DocumentSplittingExample {
2231

@@ -46,6 +55,11 @@ public Flux<String> getContent() {
4655
}
4756

4857
public static void main(String[] args) throws IOException, InterruptedException {
58+
useCustomChunker();
59+
useInbuiltChunker();
60+
}
61+
62+
private static void useInbuiltChunker() throws IOException, InterruptedException {
4963
byte[] pdfBytes = getPdfDoc();
5064
PDFDocument pdfDoc = new PDFDocument(pdfBytes);
5165

@@ -68,6 +82,70 @@ public static void main(String[] args) throws IOException, InterruptedException
6882
});
6983
}
7084

85+
public static void useCustomChunker() throws IOException, InterruptedException {
86+
87+
String example = EmbeddedResourceLoader.readFile("example.md",
88+
DocumentSplittingExample.class);
89+
90+
// Define how we are splitting tokens, in this case we are splitting on headers of an md file
91+
// i.e <new line> followed by one or more # characters
92+
TextSplitter textSplitter = (doc, numTokens) -> {
93+
// Split on headers
94+
Pattern pattern = Pattern.compile("(\\r?\\n|\\r)\s*#+", Pattern.MULTILINE);
95+
96+
Flux<Integer> splitPoints = Flux.fromStream(pattern.matcher(doc).results())
97+
.map(window -> window.start());
98+
99+
return createWindows(doc, splitPoints);
100+
};
101+
102+
// Split into single sections
103+
CountSplitCondition condition = new CountSplitCondition(1, textSplitter);
104+
105+
Splitter splitter = Splitter
106+
.builder()
107+
.addChunkEndCondition(condition)
108+
// No overlap
109+
.setOverlapCondition(NoOverlapCondition.build())
110+
// Tidy up the text
111+
.trimWhitespace()
112+
.build();
113+
114+
String chunks = splitter
115+
.splitDocument(new TextDocument(example))
116+
.collectList()
117+
.map(it -> it.stream()
118+
.map(chunk -> chunk.getContents())
119+
.collect(Collectors.joining("\n============\n")))
120+
.block();
121+
122+
System.out.println(chunks);
123+
}
124+
125+
/*
126+
* Transforms: [ 2, 10, 20, 100 ] -> [ (0, 2), (2, 10), (10, 20), (20, 100), (100, <doc length>)
127+
* ]
128+
*/
129+
private static List<SplitPoint> createWindows(String doc, Flux<Integer> splitPoints) {
130+
return Flux.concat(
131+
Flux.just(0),
132+
splitPoints,
133+
Flux.just(doc.length()))
134+
.window(2, 1)
135+
.concatMap(window -> {
136+
return window.collectList()
137+
.flatMap(list -> {
138+
if (list.size() <= 1) {
139+
return Mono.empty();
140+
}
141+
return Mono.just(
142+
new SplitPoint(list.get(0), list.get(1)));
143+
});
144+
})
145+
.collectList()
146+
.block();
147+
}
148+
71149
private static byte[] getPdfDoc() throws IOException, InterruptedException {
72150
HttpResponse<byte[]> doc = HttpClient.newHttpClient()
73151
.send(HttpRequest.newBuilder()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
## Section 1
2+
3+
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna
4+
aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis
5+
aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint
6+
occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
7+
8+
## Section 2
9+
10+
Another section.
11+
12+
### Subsection 1
13+
14+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10.
15+
16+
# Section 3
17+
18+
This is the last section.
19+
20+
```
21+
some code
22+
```

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/Splitter.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,16 @@ private static List<Chunk> chunkDocument(List<ChunkEndCondition> chunkEndConditi
118118
"This entier chunk consists of overlapped data, this will result in infinite loop. Skipping this chunk.");
119119

120120
// previous chunk should already contain this text..skip it
121-
doc = doc.substring(previousChunkEndIndex, doc.length());
121+
doc = doc.substring(Math.min(previousChunkEndIndex, doc.length()),
122+
doc.length());
122123

123-
previousChunkEndIndex = 0;
124+
previousChunkEndIndex = -1;
124125
continue;
125126
}
126127

127128
int overlapIndex = overlapCondition.getOverlapIndex(chunkText);
128129
previousChunkEndIndex = chunkText.length() - overlapIndex;
129-
doc = doc.substring(overlapIndex, doc.length());
130+
doc = doc.substring(Math.min(overlapIndex, doc.length()), doc.length());
130131

131132
chunks.add(new Chunk(chunkText));
132133
} else {

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/TextSplitter.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantic.kernel.rag.splitting;
33

4-
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoints;
5-
4+
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoint;
65
import java.util.List;
76

87
/**
@@ -16,7 +15,9 @@ public interface TextSplitter {
1615
* @param doc the document to split
1716
* @return the split points
1817
*/
19-
List<SplitPoints> getSplitPoints(String doc);
18+
default List<SplitPoint> getSplitPoints(String doc) {
19+
return getNSplitPoints(doc, Integer.MAX_VALUE);
20+
}
2021

2122
/**
2223
* Get the first n split points for the given document.
@@ -25,5 +26,5 @@ public interface TextSplitter {
2526
* @param n the number of split points to get
2627
* @return the split points
2728
*/
28-
List<SplitPoints> getNSplitPoints(String doc, int n);
29+
List<SplitPoint> getNSplitPoints(String doc, int n);
2930
}

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/overlap/CountOverlapCondition.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import com.microsoft.semantic.kernel.rag.splitting.OverlapCondition;
55
import com.microsoft.semantic.kernel.rag.splitting.TextSplitter;
6-
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoints;
6+
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoint;
77
import java.util.List;
88

99
/**
@@ -25,7 +25,7 @@ public CountOverlapCondition(int count, TextSplitter splitter) {
2525

2626
@Override
2727
public int getOverlapIndex(String chunk) {
28-
List<SplitPoints> splitPoints = splitter.getSplitPoints(chunk);
28+
List<SplitPoint> splitPoints = splitter.getSplitPoints(chunk);
2929

3030
if (splitPoints.size() == 0) {
3131
return 0;

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/overlap/NoOverlapCondition.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ public class NoOverlapCondition implements OverlapCondition {
1111
public NoOverlapCondition() {
1212
}
1313

14+
public static OverlapCondition build() {
15+
return new NoOverlapCondition();
16+
}
17+
1418
@Override
1519
public int getOverlapIndex(String chunk) {
1620
return chunk.length();

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/overlap/PercentageOverlapCondition.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import com.microsoft.semantic.kernel.rag.splitting.OverlapCondition;
55
import com.microsoft.semantic.kernel.rag.splitting.TextSplitter;
6-
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoints;
6+
import com.microsoft.semantic.kernel.rag.splitting.splitconditions.SplitPoint;
77
import java.util.List;
88
import org.slf4j.Logger;
99

@@ -31,11 +31,11 @@ public PercentageOverlapCondition(float percentage, TextSplitter splitter) {
3131

3232
@Override
3333
public int getOverlapIndex(String chunk) {
34-
List<SplitPoints> splitPoints = splitter.getSplitPoints(chunk);
34+
List<SplitPoint> splitPoints = splitter.getSplitPoints(chunk);
3535

3636
float index = chunk.length() * (100.0f - percentage) / 100.0f;
3737

38-
for (SplitPoints splitPoint : splitPoints) {
38+
for (SplitPoint splitPoint : splitPoints) {
3939
if (splitPoint.getEnd() > index) {
4040
return splitPoint.getStart();
4141
}

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/splitconditions/CountSplitCondition.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ public CountSplitCondition(int count, TextSplitter splitter) {
2121

2222
@Override
2323
public int getEndOfNextChunk(String doc) {
24-
List<SplitPoints> splitPoints = splitter.getNSplitPoints(doc, count);
24+
List<SplitPoint> splitPoints = splitter.getNSplitPoints(doc, count)
25+
.stream()
26+
.filter(it -> it != null)
27+
.filter(it -> it.getEnd() != 0)
28+
.filter(it -> it.getEnd() != it.getStart())
29+
.filter(it -> it.getStart() != doc.length())
30+
.toList();
2531

2632
if (splitPoints.size() < count) {
2733
return splitPoints.get(splitPoints.size() - 1).getEnd();

samples/semantickernel-sample-plugins/semantickernel-text-splitter-plugin/src/main/java/com/microsoft/semantic/kernel/rag/splitting/splitconditions/RegexSplitter.java

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,13 @@ public RegexSplitter(Pattern pattern, int trivialSplitLength) {
4444
}
4545

4646
@Override
47-
public List<SplitPoints> getSplitPoints(String doc) {
48-
return getNSplitPoints(doc, Integer.MAX_VALUE);
49-
}
50-
51-
@Override
52-
public List<SplitPoints> getNSplitPoints(String doc, int n) {
47+
public List<SplitPoint> getNSplitPoints(String doc, int n) {
5348
Matcher matcher = pattern.matcher(doc);
5449

5550
List<MatchResult> points = matcher.results()
5651
.collect(Collectors.toList());
5752

58-
List<SplitPoints> result = new ArrayList<>();
53+
List<SplitPoint> result = new ArrayList<>();
5954

6055
int previousEnd = 0;
6156
for (MatchResult point : points) {
@@ -66,19 +61,19 @@ public List<SplitPoints> getNSplitPoints(String doc, int n) {
6661
trivialSplitLength)) {
6762
continue;
6863
}
69-
result.add(new SplitPoints(previousEnd, point.end()));
64+
result.add(new SplitPoint(previousEnd, point.end()));
7065
previousEnd = point.end();
7166
if (result.size() >= n) {
7267
break;
7368
}
7469
}
7570

7671
if (result.size() < n && !isTrivialSplit(previousEnd, doc.length(), doc, 1)) {
77-
result.add(new SplitPoints(previousEnd, doc.length()));
72+
result.add(new SplitPoint(previousEnd, doc.length()));
7873
}
7974

8075
if (result.isEmpty()) {
81-
return List.of(new SplitPoints(0, doc.length()));
76+
return List.of(new SplitPoint(0, doc.length()));
8277
}
8378

8479
return result;
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
* A class that represents the start and end points of a split. I.e if splitting by word, these
66
* would be the indices of the first and last char in the word within the chunk.
77
*/
8-
public class SplitPoints {
8+
public class SplitPoint {
99

1010
private final int start;
1111
private final int end;
1212

13-
public SplitPoints(int start, int end) {
13+
public SplitPoint(int start, int end) {
1414
this.start = start;
1515
this.end = end;
1616
}

0 commit comments

Comments
 (0)