Skip to content

Commit 436ec11

Browse files
authored
Text similarity reranker chunks and scores snippets (#133576)
1 parent 0699e77 commit 436ec11

File tree

17 files changed

+640
-294
lines changed

17 files changed

+640
-294
lines changed

docs/changelog/133576.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 133576
2+
summary: Text similarity reranker chunks and scores snippets
3+
area: Relevance
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,6 @@ public interface ChunkingSettings extends ToXContentObject, VersionedNamedWritea
2424
* @return The max chunk size specified, or null if not specified
2525
*/
2626
Integer maxChunkSize();
27+
28+
/**
 * Validates these chunking settings.
 * The default implementation is a no-op; implementations may override it to reject invalid values.
 */
default void validate() {}
2729
}

x-pack/plugin/core/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@
234234
exports org.elasticsearch.xpack.core.watcher.watch;
235235
exports org.elasticsearch.xpack.core.watcher;
236236
exports org.elasticsearch.xpack.core.security.authc.apikey;
237+
exports org.elasticsearch.xpack.core.common.chunks;
237238

238239
provides org.elasticsearch.action.admin.cluster.node.info.ComponentVersionNumber
239240
with
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.common.chunks;
9+
10+
import org.apache.lucene.analysis.standard.StandardAnalyzer;
11+
import org.apache.lucene.document.Document;
12+
import org.apache.lucene.document.Field;
13+
import org.apache.lucene.document.TextField;
14+
import org.apache.lucene.index.DirectoryReader;
15+
import org.apache.lucene.index.IndexWriter;
16+
import org.apache.lucene.index.IndexWriterConfig;
17+
import org.apache.lucene.search.BooleanClause;
18+
import org.apache.lucene.search.IndexSearcher;
19+
import org.apache.lucene.search.Query;
20+
import org.apache.lucene.search.ScoreDoc;
21+
import org.apache.lucene.search.TopDocs;
22+
import org.apache.lucene.store.ByteBuffersDirectory;
23+
import org.apache.lucene.store.Directory;
24+
import org.apache.lucene.util.QueryBuilder;
25+
26+
import java.io.IOException;
27+
import java.util.ArrayList;
28+
import java.util.List;
29+
30+
/**
31+
* Utility class for scoring pre-determined chunks using an in-memory Lucene index.
32+
*/
33+
public class MemoryIndexChunkScorer {
34+
35+
private static final String CONTENT_FIELD = "content";
36+
37+
private final StandardAnalyzer analyzer;
38+
39+
public MemoryIndexChunkScorer() {
40+
// TODO: Allow analyzer to be customizable and/or read from the field mapping
41+
this.analyzer = new StandardAnalyzer();
42+
}
43+
44+
/**
45+
* Creates an in-memory index of chunks, or chunks, returns ordered, scored list.
46+
*
47+
* @param chunks the list of text chunks to score
48+
* @param inferenceText the query text to compare against
49+
* @param maxResults maximum number of results to return
50+
* @return list of scored chunks ordered by relevance
51+
* @throws IOException on failure scoring chunks
52+
*/
53+
public List<ScoredChunk> scoreChunks(List<String> chunks, String inferenceText, int maxResults) throws IOException {
54+
if (chunks == null || chunks.isEmpty() || inferenceText == null || inferenceText.trim().isEmpty()) {
55+
return new ArrayList<>();
56+
}
57+
58+
try (Directory directory = new ByteBuffersDirectory()) {
59+
IndexWriterConfig config = new IndexWriterConfig(analyzer);
60+
try (IndexWriter writer = new IndexWriter(directory, config)) {
61+
for (String chunk : chunks) {
62+
Document doc = new Document();
63+
doc.add(new TextField(CONTENT_FIELD, chunk, Field.Store.YES));
64+
writer.addDocument(doc);
65+
}
66+
writer.commit();
67+
}
68+
69+
try (DirectoryReader reader = DirectoryReader.open(directory)) {
70+
IndexSearcher searcher = new IndexSearcher(reader);
71+
72+
org.apache.lucene.util.QueryBuilder qb = new QueryBuilder(analyzer);
73+
Query query = qb.createBooleanQuery(CONTENT_FIELD, inferenceText, BooleanClause.Occur.SHOULD);
74+
int numResults = Math.min(maxResults, chunks.size());
75+
TopDocs topDocs = searcher.search(query, numResults);
76+
77+
List<ScoredChunk> scoredChunks = new ArrayList<>();
78+
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
79+
Document doc = reader.storedFields().document(scoreDoc.doc);
80+
String content = doc.get(CONTENT_FIELD);
81+
scoredChunks.add(new ScoredChunk(content, scoreDoc.score));
82+
}
83+
84+
// It's possible that no chunks were scorable (for example, a semantic match that does not have a lexical match).
85+
// In this case, we'll return the first N chunks with a score of 0.
86+
// TODO: consider parameterizing this
87+
return scoredChunks.isEmpty() == false
88+
? scoredChunks
89+
: chunks.subList(0, Math.min(maxResults, chunks.size())).stream().map(c -> new ScoredChunk(c, 0.0f)).toList();
90+
}
91+
}
92+
}
93+
94+
/**
95+
* Represents a chunk with its relevance score.
96+
*/
97+
public record ScoredChunk(String content, float score) {}
98+
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.common.chunks;
9+
10+
import org.elasticsearch.test.ESTestCase;
11+
12+
import java.io.IOException;
13+
import java.util.Arrays;
14+
import java.util.List;
15+
16+
import static org.hamcrest.Matchers.equalTo;
17+
import static org.hamcrest.Matchers.greaterThan;
18+
19+
public class MemoryIndexChunkScorerTests extends ESTestCase {
20+
21+
private static final List<String> CHUNKS = Arrays.asList(
22+
"Cats like to sleep all day and play with mice",
23+
"Dogs are loyal companions and great pets",
24+
"The weather today is very sunny and warm",
25+
"Dogs love to play with toys and go for walks",
26+
"Elasticsearch is a great search engine"
27+
);
28+
29+
public void testScoreChunks() throws IOException {
30+
MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
31+
32+
String inferenceText = "dogs play walk";
33+
int maxResults = 3;
34+
35+
List<MemoryIndexChunkScorer.ScoredChunk> scoredChunks = scorer.scoreChunks(CHUNKS, inferenceText, maxResults);
36+
37+
assertEquals(maxResults, scoredChunks.size());
38+
39+
// The chunks about dogs should score highest, followed by the chunk about cats
40+
MemoryIndexChunkScorer.ScoredChunk chunk = scoredChunks.getFirst();
41+
assertTrue(chunk.content().equalsIgnoreCase("Dogs love to play with toys and go for walks"));
42+
assertThat(chunk.score(), greaterThan(0f));
43+
44+
chunk = scoredChunks.get(1);
45+
assertTrue(chunk.content().equalsIgnoreCase("Dogs are loyal companions and great pets"));
46+
assertThat(chunk.score(), greaterThan(0f));
47+
48+
chunk = scoredChunks.get(2);
49+
assertTrue(chunk.content().equalsIgnoreCase("Cats like to sleep all day and play with mice"));
50+
assertThat(chunk.score(), greaterThan(0f));
51+
52+
// Scores should be in descending order
53+
for (int i = 1; i < scoredChunks.size(); i++) {
54+
assertTrue(scoredChunks.get(i - 1).score() >= scoredChunks.get(i).score());
55+
}
56+
}
57+
58+
public void testEmptyChunks() throws IOException {
59+
60+
int maxResults = 3;
61+
62+
MemoryIndexChunkScorer scorer = new MemoryIndexChunkScorer();
63+
64+
// Zero results
65+
List<MemoryIndexChunkScorer.ScoredChunk> scoredChunks = scorer.scoreChunks(CHUNKS, "puggles", maxResults);
66+
assertEquals(maxResults, scoredChunks.size());
67+
68+
// There were no results so we return the first N chunks in order
69+
MemoryIndexChunkScorer.ScoredChunk chunk = scoredChunks.getFirst();
70+
assertTrue(chunk.content().equalsIgnoreCase("Cats like to sleep all day and play with mice"));
71+
assertThat(chunk.score(), equalTo(0f));
72+
73+
chunk = scoredChunks.get(1);
74+
assertTrue(chunk.content().equalsIgnoreCase("Dogs are loyal companions and great pets"));
75+
assertThat(chunk.score(), equalTo(0f));
76+
77+
chunk = scoredChunks.get(2);
78+
assertTrue(chunk.content().equalsIgnoreCase("The weather today is very sunny and warm"));
79+
assertThat(chunk.score(), equalTo(0f));
80+
81+
// Null and Empty chunk input
82+
scoredChunks = scorer.scoreChunks(List.of(), "puggles", maxResults);
83+
assertTrue(scoredChunks.isEmpty());
84+
85+
scoredChunks = scorer.scoreChunks(CHUNKS, "", maxResults);
86+
assertTrue(scoredChunks.isEmpty());
87+
88+
scoredChunks = scorer.scoreChunks(null, "puggles", maxResults);
89+
assertTrue(scoredChunks.isEmpty());
90+
91+
scoredChunks = scorer.scoreChunks(CHUNKS, null, maxResults);
92+
assertTrue(scoredChunks.isEmpty());
93+
}
94+
95+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/RecursiveChunkingSettings.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,25 @@ public RecursiveChunkingSettings(StreamInput in) throws IOException {
5252
separators = in.readCollectionAsList(StreamInput::readString);
5353
}
5454

55+
@Override
56+
public void validate() {
57+
ValidationException validationException = new ValidationException();
58+
59+
if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
60+
validationException.addValidationError(
61+
ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
62+
);
63+
64+
if (separators != null && separators.isEmpty()) {
65+
validationException.addValidationError("Recursive chunking settings can not have an empty list of separators");
66+
}
67+
68+
if (validationException.validationErrors().isEmpty() == false) {
69+
throw validationException;
70+
}
71+
}
72+
}
73+
5574
public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
5675
ValidationException validationException = new ValidationException();
5776

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,27 @@ public Integer maxChunkSize() {
5959
return maxChunkSize;
6060
}
6161

62+
@Override
63+
public void validate() {
64+
ValidationException validationException = new ValidationException();
65+
66+
if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
67+
validationException.addValidationError(
68+
ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
69+
);
70+
}
71+
72+
if (sentenceOverlap > 1 || sentenceOverlap < 0) {
73+
validationException.addValidationError(
74+
ChunkingSettingsOptions.SENTENCE_OVERLAP + "[" + sentenceOverlap + "] must be either 0 or 1"
75+
);
76+
}
77+
78+
if (validationException.validationErrors().isEmpty() == false) {
79+
throw validationException;
80+
}
81+
}
82+
6283
@Override
6384
public Map<String, Object> asMap() {
6485
return Map.of(

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException {
4848
overlap = in.readInt();
4949
}
5050

51+
@Override
52+
public void validate() {
53+
ValidationException validationException = new ValidationException();
54+
55+
if (maxChunkSize < MAX_CHUNK_SIZE_LOWER_LIMIT) {
56+
validationException.addValidationError(
57+
ChunkingSettingsOptions.MAX_CHUNK_SIZE + "[" + maxChunkSize + "] must be above " + MAX_CHUNK_SIZE_LOWER_LIMIT
58+
);
59+
}
60+
61+
if (overlap > maxChunkSize / 2) {
62+
validationException.addValidationError(
63+
ChunkingSettingsOptions.OVERLAP + "[" + overlap + "] must be less than or equal to half of max chunk size"
64+
);
65+
}
66+
67+
if (validationException.validationErrors().isEmpty() == false) {
68+
throw validationException;
69+
}
70+
}
71+
5172
@Override
5273
public Map<String, Object> asMap() {
5374
return Map.of(
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.rank.textsimilarity;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.inference.ChunkingSettings;
14+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
15+
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
16+
17+
import java.io.IOException;
18+
import java.util.Map;
19+
import java.util.Objects;
20+
21+
public class ChunkScorerConfig implements Writeable {
22+
23+
public final Integer size;
24+
private final String inferenceText;
25+
private final ChunkingSettings chunkingSettings;
26+
27+
public static final int DEFAULT_CHUNK_SIZE = 300;
28+
public static final int DEFAULT_SIZE = 1;
29+
30+
public static ChunkingSettings createChunkingSettings(Integer chunkSize) {
31+
int chunkSizeOrDefault = chunkSize != null ? chunkSize : DEFAULT_CHUNK_SIZE;
32+
ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(chunkSizeOrDefault, 0);
33+
chunkingSettings.validate();
34+
return chunkingSettings;
35+
}
36+
37+
public static ChunkingSettings chunkingSettingsFromMap(Map<String, Object> map) {
38+
39+
if (map == null || map.isEmpty()) {
40+
return createChunkingSettings(DEFAULT_CHUNK_SIZE);
41+
}
42+
43+
if (map.size() == 1 && map.containsKey("max_chunk_size")) {
44+
return createChunkingSettings((Integer) map.get("max_chunk_size"));
45+
}
46+
47+
return ChunkingSettingsBuilder.fromMap(map);
48+
}
49+
50+
public ChunkScorerConfig(StreamInput in) throws IOException {
51+
this.size = in.readOptionalVInt();
52+
this.inferenceText = in.readString();
53+
Map<String, Object> chunkingSettingsMap = in.readGenericMap();
54+
this.chunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
55+
}
56+
57+
public ChunkScorerConfig(Integer size, ChunkingSettings chunkingSettings) {
58+
this(size, null, chunkingSettings);
59+
}
60+
61+
public ChunkScorerConfig(Integer size, String inferenceText, ChunkingSettings chunkingSettings) {
62+
this.size = size;
63+
this.inferenceText = inferenceText;
64+
this.chunkingSettings = chunkingSettings;
65+
}
66+
67+
@Override
68+
public void writeTo(StreamOutput out) throws IOException {
69+
out.writeOptionalVInt(size);
70+
out.writeString(inferenceText);
71+
out.writeGenericMap(chunkingSettings.asMap());
72+
}
73+
74+
public Integer size() {
75+
return size;
76+
}
77+
78+
public String inferenceText() {
79+
return inferenceText;
80+
}
81+
82+
public ChunkingSettings chunkingSettings() {
83+
return chunkingSettings;
84+
}
85+
86+
@Override
87+
public boolean equals(Object o) {
88+
if (this == o) return true;
89+
if (o == null || getClass() != o.getClass()) return false;
90+
ChunkScorerConfig that = (ChunkScorerConfig) o;
91+
return Objects.equals(size, that.size)
92+
&& Objects.equals(inferenceText, that.inferenceText)
93+
&& Objects.equals(chunkingSettings, that.chunkingSettings);
94+
}
95+
96+
@Override
97+
public int hashCode() {
98+
return Objects.hash(size, inferenceText, chunkingSettings);
99+
}
100+
}

0 commit comments

Comments
 (0)