
Commit 1db5f0d

[9.1] [ML] Restore delayed string copying (elastic#135242) (elastic#135464)
This commit restores the behaviour introduced in elastic#125837, which was inadvertently undone by changes in elastic#121041: specifically, copying Strings as part of calling Request.chunkText() is delayed until the request is being executed.

In addition, doChunkedInfer() and its implementations are refactored to take a List<ChunkInferenceInput> rather than an EmbeddingsInput, since the EmbeddingsInput passed into doChunkedInfer() was immediately discarded after the ChunkInferenceInput list had been extracted from it. This allowed the EmbeddingsInput class to be refactored so that it no longer knows about ChunkInferenceInput, simplifying it significantly.

This commit also simplifies EmbeddingRequestChunker.Request to take only the input String rather than the entire list of inputs, since only one input is actually needed. This prevents Requests from retaining a reference to the input list, potentially allowing it to be GC'd sooner.

(cherry picked from commit f3447d3)
1 parent b10de03 commit 1db5f0d
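
The core idea being restored (deferring the String copy in Request.chunkText() until a request actually executes) can be illustrated with a minimal, self-contained sketch. This is not the Elasticsearch code: ChunkRequest and Demo below are hypothetical stand-ins, and only ChunkOffset mirrors a name from the diff.

import java.util.List;

// Hypothetical stand-in mirroring the ChunkOffset record used in the diff below.
record ChunkOffset(int start, int end) {}

// Hypothetical request that keeps only the offsets; no substring is taken when it is built.
record ChunkRequest(String input, ChunkOffset chunk) {
    String chunkText() {
        // The copy of the chunk is made only here, i.e. when the request is executed.
        return input.substring(chunk.start(), chunk.end());
    }
}

class Demo {
    public static void main(String[] args) {
        String doc = "a long input document that gets split into chunks";
        List<ChunkRequest> requests = List.of(
            new ChunkRequest(doc, new ChunkOffset(0, 12)),
            new ChunkRequest(doc, new ChunkOffset(12, 26))
        );
        // Building the batch above allocated no chunk copies; they are created only now.
        requests.forEach(r -> System.out.println(r.chunkText()));
    }
}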

69 files changed: +291 / -297 lines changed


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 7 additions & 6 deletions
@@ -37,16 +37,16 @@
  * a single large input that has been chunked may spread over
  * multiple batches.
  *
- * The final aspect it to gather the responses from the batch
+ * The final aspect is to gather the responses from the batch
  * processing and map the results back to the original element
  * in the input list.
  */
 public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
 
     // Visible for testing
-    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
+    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
         public String chunkText() {
-            return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
+            return input.substring(chunk.start(), chunk.end());
         }
     }
 
@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
 
     private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
 
-    // The maximum number of chunks that is stored for any input text.
+    // The maximum number of chunks that are stored for any input text.
     // If the configured chunker chunks the text into more chunks, each
     // chunk is sent to the inference service separately, but the results
    // are merged so that only this maximum number of chunks is stored.
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
                 chunkingSettings = defaultChunkingSettings;
             }
             Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
-            List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
+            String inputString = inputs.get(inputIndex).input();
+            List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
             int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
             resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
             resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
             } else {
                 resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
             }
-            allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
+            allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
         }
     }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 22 additions & 38 deletions
@@ -8,63 +8,47 @@
 package org.elasticsearch.xpack.inference.external.http.sender;
 
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ChunkInferenceInput;
-import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InputType;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
-import java.util.stream.Collectors;
 
 public class EmbeddingsInput extends InferenceInputs {
-
-    public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
-        if (inferenceInputs instanceof EmbeddingsInput == false) {
-            throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
-        }
-
-        return (EmbeddingsInput) inferenceInputs;
-    }
-
-    private final Supplier<List<ChunkInferenceInput>> listSupplier;
+    private final Supplier<List<String>> inputListSupplier;
     private final InputType inputType;
+    private final AtomicBoolean supplierInvoked = new AtomicBoolean();
 
-    public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
-        super(false);
-        this.listSupplier = Objects.requireNonNull(inputSupplier);
-        this.inputType = inputType;
+    public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
+        this(() -> input, inputType, false);
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
-        this(input, chunkingSettings, inputType, false);
+    public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
+        this(() -> input, inputType, stream);
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
-        this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
+    public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
+        this(inputSupplier, inputType, false);
    }
 
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
-        this(input, inputType, false);
-    }
-
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
+    private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
         super(stream);
-        Objects.requireNonNull(input);
-        this.listSupplier = () -> input;
+        this.inputListSupplier = Objects.requireNonNull(inputSupplier);
         this.inputType = inputType;
     }
 
-    public List<ChunkInferenceInput> getInputs() {
-        return this.listSupplier.get();
-    }
-
-    public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
-        return new EmbeddingsInput(input, null, inputType);
-    }
-
-    public List<String> getStringInputs() {
-        return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
+    /**
+     * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply
+     * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input
+     * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created.
+     * Calling this method twice in a non-production environment will cause an {@link AssertionError} to be thrown.
+     *
+     * @return a list of String embedding inputs
+     */
+    public List<String> getInputs() {
+        assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice";
+        return inputListSupplier.get();
     }
 
     public InputType getInputType() {
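
Based on the getInputs() contract shown above, a caller is expected to read the inputs exactly once: the supplier is invoked (and any deferred chunk copies are made) at that point, and a second call fails the assertion when assertions are enabled. The following is a minimal, hypothetical usage sketch; EmbeddingsInputUsageSketch is made up, it assumes the inference plugin classes are on the classpath, and it relies only on the constructors visible in this diff.

import java.util.List;
import java.util.function.Supplier;

import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;

class EmbeddingsInputUsageSketch {
    static List<String> readOnce(Supplier<List<String>> lazyChunks) {
        // inputType is @Nullable in the constructors above, so null is acceptable for this sketch.
        EmbeddingsInput input = new EmbeddingsInput(lazyChunks, null);

        // The supplier is invoked (and the input Strings materialised) only at this point.
        List<String> docs = input.getInputs();

        // A second call would trip the assertion added in this commit:
        // input.getInputs();  // AssertionError: EmbeddingsInput supplier invoked twice
        return docs;
    }
}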

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
+        var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
         var truncatedInput = truncate(docsInput, maxInputTokens);
         var request = requestCreator.apply(truncatedInput);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 3 additions & 3 deletions
@@ -101,7 +101,7 @@ private static InferenceInputs createInput(
                 if (validationException.validationErrors().isEmpty() == false) {
                     throw validationException;
                 }
-                yield new EmbeddingsInput(input, null, inputType, stream);
+                yield new EmbeddingsInput(input, inputType, stream);
             }
             default -> throw new ElasticsearchStatusException(
                 Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -140,7 +140,7 @@ public void chunkedInfer(
         }
 
         // a non-null query is not supported and is dropped by all providers
-        doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
+        doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
     }
 
     protected abstract void doInfer(
@@ -164,7 +164,7 @@ protected abstract void doUnifiedCompletionInfer(
 
     protected abstract void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java

Lines changed: 3 additions & 2 deletions
@@ -71,8 +71,9 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 4 additions & 3 deletions
@@ -16,6 +16,7 @@
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -321,7 +322,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -336,14 +337,14 @@ protected void doChunkedInfer(
         var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents());
 
         List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-            inputs.getInputs(),
+            inputs,
             EMBEDDING_MAX_BATCH_SIZE,
             alibabaCloudSearchModel.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
             var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings);
-            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
         }
     }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java

Lines changed: 2 additions & 2 deletions
@@ -71,8 +71,8 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java

Lines changed: 2 additions & 2 deletions
@@ -56,8 +56,8 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getStringInputs();
+        EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class);
+        List<String> docsInput = input.getInputs();
         InputType inputType = input.getInputType();
 
         var serviceSettings = embeddingsModel.getServiceSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 4 additions & 3 deletions
@@ -17,6 +17,7 @@
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -136,7 +137,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -147,14 +148,14 @@ protected void doChunkedInfer(
             var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
 
             List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
-                inputs.getInputs(),
+                inputs,
                 maxBatchSize,
                 baseAmazonBedrockModel.getConfigurations().getChunkingSettings()
             ).batchRequestsWithListeners(listener);
 
             for (var request : batchedRequests) {
                 var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
-                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+                action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener());
             }
         } else {
             listener.onFailure(createInvalidModelException(model));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 2 additions & 2 deletions
@@ -15,6 +15,7 @@
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,7 +27,6 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -222,7 +222,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc
     @Override
     protected void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
