Skip to content

Commit a0ec9d2

Browse files
committed
Defer chunking embedding input strings
This commit restores the behaviour introduced in elastic#125837, which was inadvertently undone by changes in elastic#121041: the String copy performed by Request.chunkText() is deferred until the request is actually being executed. In addition, doChunkedInfer() and its implementations are refactored to take a List<ChunkInferenceInput> rather than an EmbeddingsInput, since the EmbeddingsInput passed into doChunkedInfer() was immediately discarded after the ChunkInferenceInput list had been extracted from it. This in turn allowed the EmbeddingsInput class to be refactored so that it no longer knows about ChunkInferenceInput, simplifying it significantly. This commit also simplifies EmbeddingRequestChunker.Request to take only its own input String rather than the list of all inputs, since only that one input is actually needed. As a result, Requests no longer retain a reference to the input list, potentially allowing it to be garbage-collected sooner.
1 parent 0022605 commit a0ec9d2

File tree

73 files changed

+254
-338
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+254
-338
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@
3737
* a single large input that has been chunked may spread over
3838
* multiple batches.
3939
*
40-
* The final aspect it to gather the responses from the batch
40+
* The final aspect is to gather the responses from the batch
4141
* processing and map the results back to the original element
4242
* in the input list.
4343
*/
4444
public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
4545

4646
// Visible for testing
47-
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
47+
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
4848
public String chunkText() {
49-
return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
49+
return input.substring(chunk.start(), chunk.end());
5050
}
5151
}
5252

@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
6060

6161
private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
6262

63-
// The maximum number of chunks that is stored for any input text.
63+
// The maximum number of chunks that are stored for any input text.
6464
// If the configured chunker chunks the text into more chunks, each
6565
// chunk is sent to the inference service separately, but the results
6666
// are merged so that only this maximum number of chunks is stored.
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
112112
chunkingSettings = defaultChunkingSettings;
113113
}
114114
Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
115-
List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
115+
String inputString = inputs.get(inputIndex).input();
116+
List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
116117
int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
117118
resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
118119
resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
129130
} else {
130131
resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
131132
}
132-
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
133+
allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
133134
}
134135
}
135136

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,63 +8,40 @@
88
package org.elasticsearch.xpack.inference.external.http.sender;
99

1010
import org.elasticsearch.core.Nullable;
11-
import org.elasticsearch.inference.ChunkInferenceInput;
12-
import org.elasticsearch.inference.ChunkingSettings;
1311
import org.elasticsearch.inference.InputType;
1412

1513
import java.util.List;
1614
import java.util.Objects;
1715
import java.util.function.Supplier;
18-
import java.util.stream.Collectors;
1916

2017
public class EmbeddingsInput extends InferenceInputs {
21-
22-
public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
23-
if (inferenceInputs instanceof EmbeddingsInput == false) {
24-
throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
25-
}
26-
27-
return (EmbeddingsInput) inferenceInputs;
28-
}
29-
30-
private final Supplier<List<ChunkInferenceInput>> listSupplier;
18+
private Supplier<List<String>> inputListSupplier;
3119
private final InputType inputType;
3220

33-
public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
34-
super(false);
35-
this.listSupplier = Objects.requireNonNull(inputSupplier);
36-
this.inputType = inputType;
21+
public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
22+
this(() -> input, inputType, false);
3723
}
3824

39-
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
40-
this(input, chunkingSettings, inputType, false);
25+
public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
26+
this(() -> input, inputType, stream);
4127
}
4228

43-
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
44-
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
29+
public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
30+
this(inputSupplier, inputType, false);
4531
}
4632

47-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
48-
this(input, inputType, false);
49-
}
50-
51-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
33+
private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
5234
super(stream);
53-
Objects.requireNonNull(input);
54-
this.listSupplier = () -> input;
35+
this.inputListSupplier = Objects.requireNonNull(inputSupplier);
5536
this.inputType = inputType;
5637
}
5738

58-
public List<ChunkInferenceInput> getInputs() {
59-
return this.listSupplier.get();
60-
}
61-
62-
public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
63-
return new EmbeddingsInput(input, null, inputType);
64-
}
65-
66-
public List<String> getStringInputs() {
67-
return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
39+
public List<String> getInputs() {
40+
// The supplier should only be invoked once
41+
assert inputListSupplier != null;
42+
List<String> strings = inputListSupplier.get();
43+
inputListSupplier = null;
44+
return strings;
6845
}
6946

7047
public InputType getInputType() {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public void execute(
5252
Supplier<Boolean> hasRequestCompletedFunction,
5353
ActionListener<InferenceServiceResults> listener
5454
) {
55-
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
55+
var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
5656
var truncatedInput = truncate(docsInput, maxInputTokens);
5757
var request = requestCreator.apply(truncatedInput);
5858

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ private static InferenceInputs createInput(
105105
if (validationException.validationErrors().isEmpty() == false) {
106106
throw validationException;
107107
}
108-
yield new EmbeddingsInput(input, null, inputType, stream);
108+
yield new EmbeddingsInput(input, inputType, stream);
109109
}
110110
default -> throw new ElasticsearchStatusException(
111111
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -144,7 +144,7 @@ public void chunkedInfer(
144144
}
145145

146146
// a non-null query is not supported and is dropped by all providers
147-
doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
147+
doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
148148
}
149149

150150
protected abstract void doInfer(
@@ -168,7 +168,7 @@ protected abstract void doUnifiedCompletionInfer(
168168

169169
protected abstract void doChunkedInfer(
170170
Model model,
171-
EmbeddingsInput inputs,
171+
List<ChunkInferenceInput> inputs,
172172
Map<String, Object> taskSettings,
173173
InputType inputType,
174174
TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.common.util.LazyInitializable;
1616
import org.elasticsearch.core.Nullable;
1717
import org.elasticsearch.core.TimeValue;
18+
import org.elasticsearch.inference.ChunkInferenceInput;
1819
import org.elasticsearch.inference.ChunkedInference;
1920
import org.elasticsearch.inference.InferenceServiceConfiguration;
2021
import org.elasticsearch.inference.InferenceServiceExtension;
@@ -29,7 +30,6 @@
2930
import org.elasticsearch.rest.RestStatus;
3031
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
3132
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
32-
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
3333
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
3434
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
3535
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -143,7 +143,7 @@ protected void doUnifiedCompletionInfer(
143143
@Override
144144
protected void doChunkedInfer(
145145
Model model,
146-
EmbeddingsInput inputs,
146+
List<ChunkInferenceInput> inputs,
147147
Map<String, Object> taskSettings,
148148
InputType inputType,
149149
TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,9 @@ public void execute(
7171
Supplier<Boolean> hasRequestCompletedFunction,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
75-
List<String> docsInput = input.getStringInputs();
74+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
75+
76+
List<String> docsInput = input.getInputs();
7677
InputType inputType = input.getInputType();
7778

7879
AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.core.Nullable;
1818
import org.elasticsearch.core.Strings;
1919
import org.elasticsearch.core.TimeValue;
20+
import org.elasticsearch.inference.ChunkInferenceInput;
2021
import org.elasticsearch.inference.ChunkedInference;
2122
import org.elasticsearch.inference.ChunkingSettings;
2223
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -336,7 +337,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V
336337
@Override
337338
protected void doChunkedInfer(
338339
Model model,
339-
EmbeddingsInput inputs,
340+
List<ChunkInferenceInput> inputs,
340341
Map<String, Object> taskSettings,
341342
InputType inputType,
342343
TimeValue timeout,
@@ -351,14 +352,14 @@ protected void doChunkedInfer(
351352
var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
352353

353354
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
354-
inputs.getInputs(),
355+
inputs,
355356
EMBEDDING_MAX_BATCH_SIZE,
356357
alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
357358
).batchRequestsWithListeners(listener);
358359

359360
for (var request : batchedRequests) {
360361
var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
361-
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
362+
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
362363
}
363364
}
364365

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ public void execute(
7171
Supplier<Boolean> hasRequestCompletedFunction,
7272
ActionListener<InferenceServiceResults> listener
7373
) {
74-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
75-
List<String> docsInput = input.getStringInputs();
74+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
75+
List<String> docsInput = input.getInputs();
7676
InputType inputType = input.getInputType();
7777

7878
AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ public void execute(
5656
Supplier<Boolean> hasRequestCompletedFunction,
5757
ActionListener<InferenceServiceResults> listener
5858
) {
59-
EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
60-
List<String> docsInput = input.getStringInputs();
59+
EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
60+
List<String> docsInput = input.getInputs();
6161
InputType inputType = input.getInputType();
6262

6363
var serviceSettings = embeddingsModel.getServiceSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.core.IOUtils;
1919
import org.elasticsearch.core.Nullable;
2020
import org.elasticsearch.core.TimeValue;
21+
import org.elasticsearch.inference.ChunkInferenceInput;
2122
import org.elasticsearch.inference.ChunkedInference;
2223
import org.elasticsearch.inference.ChunkingSettings;
2324
import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -148,7 +149,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
148149
@Override
149150
protected void doChunkedInfer(
150151
Model model,
151-
EmbeddingsInput inputs,
152+
List<ChunkInferenceInput> inputs,
152153
Map<String, Object> taskSettings,
153154
InputType inputType,
154155
TimeValue timeout,
@@ -159,14 +160,14 @@ protected void doChunkedInfer(
159160
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
160161

161162
List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
162-
inputs.getInputs(),
163+
inputs,
163164
maxBatchSize,
164165
baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
165166
).batchRequestsWithListeners(listener);
166167

167168
for (var request : batchedRequests) {
168169
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
169-
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
170+
action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
170171
}
171172
} else {
172173
listener.onFailure(createInvalidModelException(model));

0 commit comments

Comments
 (0)