Skip to content

Commit 0cee213

Browse files
dan-rubinsteinelasticmachinedavidkyle
authored
Add RerankRequestChunker (#130485)
* Add RerankRequestChunker * Add chunking strategy generation * Adding unit tests and fixing token/word ratio * Add configurable values for long document handling strategy and maximum chunks per document * Adding back sentence overlap for rerank chunking strategy * Adding unit tests, transport version, and feature flag * Update docs/changelog/130485.yaml * Adding unit tests and refactoring code with clearer naming conventions --------- Co-authored-by: Elastic Machine <[email protected]> Co-authored-by: David Kyle <[email protected]>
1 parent 552da22 commit 0cee213

File tree

13 files changed

+1131
-22
lines changed

13 files changed

+1131
-22
lines changed

docs/changelog/130485.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 130485
2+
summary: Add `RerankRequestChunker`
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
9180000
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
ml_inference_google_model_garden_added,9179000
1+
elastic_reranker_chunking_configuration,9180000

test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ public enum FeatureFlag {
2525
"es.index_dimensions_tsid_optimization_feature_flag_enabled=true",
2626
Version.fromString("9.2.0"),
2727
null
28-
);
28+
),
29+
ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null);
2930

3031
public final String systemProperty;
3132
public final Version from;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ public class ChunkingSettingsBuilder {
1717
public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1);
1818
// Old settings used for backward compatibility for endpoints created before 8.16 when default was changed
1919
public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
20+
public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512;
21+
public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3;
22+
public static final float WORDS_PER_TOKEN = 0.75f;
2023

2124
public static ChunkingSettings fromMap(Map<String, Object> settings) {
2225
return fromMap(settings, true);
@@ -51,4 +54,17 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
5154
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
5255
};
5356
}
57+
58+
public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWordCount) {
59+
var queryTokenCount = Math.ceil(queryWordCount / WORDS_PER_TOKEN);
60+
var chunkSizeTokenCountWithFullQuery = (ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - queryTokenCount);
61+
62+
var maxChunkSizeTokenCount = Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2);
63+
if (chunkSizeTokenCountWithFullQuery > maxChunkSizeTokenCount) {
64+
maxChunkSizeTokenCount = chunkSizeTokenCountWithFullQuery;
65+
}
66+
67+
var maxChunkSizeWordCount = (int) (maxChunkSizeTokenCount * WORDS_PER_TOKEN);
68+
return new SentenceBoundaryChunkingSettings(maxChunkSizeWordCount, 1);
69+
}
5470
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.chunking;
9+
10+
import com.ibm.icu.text.BreakIterator;
11+
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.ChunkingSettings;
14+
import org.elasticsearch.inference.InferenceServiceResults;
15+
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
16+
17+
import java.util.ArrayList;
18+
import java.util.HashSet;
19+
import java.util.List;
20+
import java.util.Set;
21+
22+
public class RerankRequestChunker {
23+
private final List<String> inputs;
24+
private final List<RerankChunks> rerankChunks;
25+
26+
public RerankRequestChunker(String query, List<String> inputs, Integer maxChunksPerDoc) {
27+
this.inputs = inputs;
28+
this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc);
29+
}
30+
31+
private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) {
32+
var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
33+
var chunks = new ArrayList<RerankChunks>();
34+
for (int i = 0; i < inputs.size(); i++) {
35+
var chunksForInput = chunker.chunk(inputs.get(i), chunkingSettings);
36+
if (maxChunksPerDoc != null && chunksForInput.size() > maxChunksPerDoc) {
37+
chunksForInput = chunksForInput.subList(0, maxChunksPerDoc);
38+
}
39+
40+
for (var chunk : chunksForInput) {
41+
chunks.add(new RerankChunks(i, inputs.get(i).substring(chunk.start(), chunk.end())));
42+
}
43+
}
44+
return chunks;
45+
}
46+
47+
public List<String> getChunkedInputs() {
48+
List<String> chunkedInputs = new ArrayList<>();
49+
for (RerankChunks chunk : rerankChunks) {
50+
chunkedInputs.add(chunk.chunkString());
51+
}
52+
53+
return chunkedInputs;
54+
}
55+
56+
public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
57+
return ActionListener.wrap(results -> {
58+
if (results instanceof RankedDocsResults rankedDocsResults) {
59+
listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));
60+
61+
} else {
62+
listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
63+
}
64+
65+
}, listener::onFailure);
66+
}
67+
68+
private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
69+
List<RankedDocsResults.RankedDoc> topRankedDocs = new ArrayList<>();
70+
Set<Integer> docIndicesSeen = new HashSet<>();
71+
72+
List<RankedDocsResults.RankedDoc> rankedDocs = new ArrayList<>(rankedDocsResults.getRankedDocs());
73+
rankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
74+
for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
75+
int chunkIndex = rankedDoc.index();
76+
int docIndex = rerankChunks.get(chunkIndex).docIndex();
77+
78+
if (docIndicesSeen.contains(docIndex) == false) {
79+
// Create a ranked doc with the full input string and the index for the document instead of the chunk
80+
RankedDocsResults.RankedDoc updatedRankedDoc = new RankedDocsResults.RankedDoc(
81+
docIndex,
82+
rankedDoc.relevanceScore(),
83+
inputs.get(docIndex)
84+
);
85+
topRankedDocs.add(updatedRankedDoc);
86+
docIndicesSeen.add(docIndex);
87+
}
88+
}
89+
90+
return new RankedDocsResults(topRankedDocs);
91+
}
92+
93+
public record RerankChunks(int docIndex, String chunkString) {};
94+
95+
private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
96+
var wordIterator = BreakIterator.getWordInstance();
97+
wordIterator.setText(query);
98+
var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
99+
return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
100+
}
101+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerServiceSettings.java

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,51 @@
77

88
package org.elasticsearch.xpack.inference.services.elasticsearch;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.inference.ModelConfigurations;
15+
import org.elasticsearch.xcontent.XContentBuilder;
1216
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
1317

1418
import java.io.IOException;
19+
import java.util.EnumSet;
20+
import java.util.Locale;
1521
import java.util.Map;
1622

23+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
24+
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
25+
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
1726
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;
1827

1928
public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
2029

2130
public static final String NAME = "elastic_reranker_service_settings";
2231

32+
public static final String LONG_DOCUMENT_STRATEGY = "long_document_strategy";
33+
public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc";
34+
35+
private static final TransportVersion ELASTIC_RERANKER_CHUNKING_CONFIGURATION = TransportVersion.fromName(
36+
"elastic_reranker_chunking_configuration"
37+
);
38+
39+
private final LongDocumentStrategy longDocumentStrategy;
40+
private final Integer maxChunksPerDoc;
41+
2342
public static ElasticRerankerServiceSettings defaultEndpointSettings() {
2443
return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32));
2544
}
2645

27-
public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
46+
public ElasticRerankerServiceSettings(
47+
ElasticsearchInternalServiceSettings other,
48+
LongDocumentStrategy longDocumentStrategy,
49+
Integer maxChunksPerDoc
50+
) {
2851
super(other);
52+
this.longDocumentStrategy = longDocumentStrategy;
53+
this.maxChunksPerDoc = maxChunksPerDoc;
54+
2955
}
3056

3157
private ElasticRerankerServiceSettings(
@@ -35,10 +61,32 @@ private ElasticRerankerServiceSettings(
3561
AdaptiveAllocationsSettings adaptiveAllocationsSettings
3662
) {
3763
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
64+
this.longDocumentStrategy = null;
65+
this.maxChunksPerDoc = null;
66+
}
67+
68+
protected ElasticRerankerServiceSettings(
69+
Integer numAllocations,
70+
int numThreads,
71+
String modelId,
72+
AdaptiveAllocationsSettings adaptiveAllocationsSettings,
73+
LongDocumentStrategy longDocumentStrategy,
74+
Integer maxChunksPerDoc
75+
) {
76+
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
77+
this.longDocumentStrategy = longDocumentStrategy;
78+
this.maxChunksPerDoc = maxChunksPerDoc;
3879
}
3980

4081
public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
4182
super(in);
83+
if (in.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) {
84+
this.longDocumentStrategy = in.readOptionalEnum(LongDocumentStrategy.class);
85+
this.maxChunksPerDoc = in.readOptionalInt();
86+
} else {
87+
this.longDocumentStrategy = null;
88+
this.maxChunksPerDoc = null;
89+
}
4290
}
4391

4492
/**
@@ -48,21 +96,93 @@ public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
4896
* {@link ValidationException} is thrown.
4997
*
5098
* @param map Source map containing the config
51-
* @return The builder
99+
* @return Parsed and validated service settings
52100
*/
53-
public static Builder fromRequestMap(Map<String, Object> map) {
101+
public static ElasticRerankerServiceSettings fromMap(Map<String, Object> map) {
54102
ValidationException validationException = new ValidationException();
55103
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);
56104

105+
LongDocumentStrategy longDocumentStrategy = null;
106+
Integer maxChunksPerDoc = null;
107+
if (ELASTIC_RERANKER_CHUNKING.isEnabled()) {
108+
longDocumentStrategy = extractOptionalEnum(
109+
map,
110+
LONG_DOCUMENT_STRATEGY,
111+
ModelConfigurations.SERVICE_SETTINGS,
112+
LongDocumentStrategy::fromString,
113+
EnumSet.allOf(LongDocumentStrategy.class),
114+
validationException
115+
);
116+
117+
maxChunksPerDoc = extractOptionalPositiveInteger(
118+
map,
119+
MAX_CHUNKS_PER_DOC,
120+
ModelConfigurations.SERVICE_SETTINGS,
121+
validationException
122+
);
123+
124+
if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
125+
validationException.addValidationError(
126+
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
127+
);
128+
}
129+
}
130+
57131
if (validationException.validationErrors().isEmpty() == false) {
58132
throw validationException;
59133
}
60134

61-
return baseSettings;
135+
return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentStrategy, maxChunksPerDoc);
136+
}
137+
138+
public LongDocumentStrategy getLongDocumentStrategy() {
139+
return longDocumentStrategy;
140+
}
141+
142+
public Integer getMaxChunksPerDoc() {
143+
return maxChunksPerDoc;
144+
}
145+
146+
@Override
147+
public void writeTo(StreamOutput out) throws IOException {
148+
super.writeTo(out);
149+
if (out.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) {
150+
out.writeOptionalEnum(longDocumentStrategy);
151+
out.writeOptionalInt(maxChunksPerDoc);
152+
}
153+
}
154+
155+
@Override
156+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
157+
builder.startObject();
158+
addInternalSettingsToXContent(builder, params);
159+
if (longDocumentStrategy != null) {
160+
builder.field(LONG_DOCUMENT_STRATEGY, longDocumentStrategy.strategyName);
161+
}
162+
if (maxChunksPerDoc != null) {
163+
builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc);
164+
}
165+
builder.endObject();
166+
return builder;
62167
}
63168

64169
@Override
65170
public String getWriteableName() {
66171
return ElasticRerankerServiceSettings.NAME;
67172
}
173+
174+
public enum LongDocumentStrategy {
175+
CHUNK("chunk"),
176+
TRUNCATE("truncate");
177+
178+
public final String strategyName;
179+
180+
LongDocumentStrategy(String strategyName) {
181+
this.strategyName = strategyName;
182+
}
183+
184+
public static LongDocumentStrategy fromString(String name) {
185+
return valueOf(name.trim().toUpperCase(Locale.ROOT));
186+
}
187+
}
68188
}

0 commit comments

Comments
 (0)