Commit 133e694
[ML] Restore delayed string copying (elastic#135242)

This commit restores the behaviour introduced in elastic#125837 which was inadvertently undone by changes in elastic#121041, specifically, delaying copying Strings as part of calling Request.chunkText() until the request is being executed.

In addition to the above change, refactor doChunkedInfer() and its implementations to take a List<ChunkInferenceInput> rather than EmbeddingsInput, since the EmbeddingsInput passed into doChunkedInfer() was immediately discarded after extracting the ChunkInferenceInput list from it. This change allowed the EmbeddingsInput class to be refactored to not know about ChunkInferenceInput, simplifying it significantly.

This commit also simplifies EmbeddingRequestChunker.Request to take only the input String rather than the entire list of all inputs, since only one input is actually needed. This change prevents Requests from retaining a reference to the input list, potentially allowing it to be GC'd faster.

(cherry picked from commit f3447d3)

# Conflicts:
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java
1 parent 6c02d66 commit 133e694
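
For readers unfamiliar with the pattern being restored, the following is a minimal, self-contained sketch of the delayed-copying idea, using simplified, hypothetical stand-in types rather than the actual Elasticsearch classes: the request keeps the original input String together with chunk offsets, and the substring (the only copy of the chunk text) is produced only when chunkText() is called at execution time.

import java.util.List;

// Simplified, hypothetical stand-ins for ChunkOffset/Request; illustrative only.
public class DelayedChunkCopySketch {

    record ChunkOffset(int start, int end) {}

    // The request keeps the single input String plus offsets; no substring copy is
    // made when the request is constructed.
    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
        // The chunk text is copied only here, when the request is actually executed.
        public String chunkText() {
            return input.substring(chunk.start(), chunk.end());
        }
    }

    public static void main(String[] args) {
        String input = "a fairly long document that will be split into chunks";
        List<Request> requests = List.of(
            new Request(0, 0, new ChunkOffset(0, 13), input),
            new Request(0, 1, new ChunkOffset(14, input.length()), input)
        );
        // Only at this point is each chunk's text materialised.
        requests.forEach(r -> System.out.println(r.chunkText()));
    }
}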

74 files changed: 1854 additions, 300 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 7 additions & 6 deletions
@@ -37,16 +37,16 @@
  * a single large input that has been chunked may spread over
  * multiple batches.
  *
- * The final aspect it to gather the responses from the batch
+ * The final aspect is to gather the responses from the batch
  * processing and map the results back to the original element
  * in the input list.
  */
 public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
 
     // Visible for testing
-    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<ChunkInferenceInput> inputs) {
+    record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) {
         public String chunkText() {
-            return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end());
+            return input.substring(chunk.start(), chunk.end());
         }
     }
 
@@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
 
     private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
 
-    // The maximum number of chunks that is stored for any input text.
+    // The maximum number of chunks that are stored for any input text.
     // If the configured chunker chunks the text into more chunks, each
     // chunk is sent to the inference service separately, but the results
     // are merged so that only this maximum number of chunks is stored.
@@ -112,7 +112,8 @@ public EmbeddingRequestChunker(
             chunkingSettings = defaultChunkingSettings;
         }
         Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker);
-        List<ChunkOffset> chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings);
+        String inputString = inputs.get(inputIndex).input();
+        List<ChunkOffset> chunks = chunker.chunk(inputString, chunkingSettings);
         int resultCount = Math.min(chunks.size(), MAX_CHUNKS);
         resultEmbeddings.add(new AtomicReferenceArray<>(resultCount));
         resultOffsetStarts.add(new ArrayList<>(resultCount));
@@ -129,7 +130,7 @@ public EmbeddingRequestChunker(
             } else {
                 resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end());
             }
-            allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs));
+            allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString));
         }
     }
 

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 22 additions & 38 deletions
@@ -8,63 +8,47 @@
 package org.elasticsearch.xpack.inference.external.http.sender;
 
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ChunkInferenceInput;
-import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InputType;
 
 import java.util.List;
 import java.util.Objects;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Supplier;
-import java.util.stream.Collectors;
 
 public class EmbeddingsInput extends InferenceInputs {
-
-    public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
-        if (inferenceInputs instanceof EmbeddingsInput == false) {
-            throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class);
-        }
-
-        return (EmbeddingsInput) inferenceInputs;
-    }
-
-    private final Supplier<List<ChunkInferenceInput>> listSupplier;
+    private final Supplier<List<String>> inputListSupplier;
     private final InputType inputType;
+    private final AtomicBoolean supplierInvoked = new AtomicBoolean();
 
-    public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
-        super(false);
-        this.listSupplier = Objects.requireNonNull(inputSupplier);
-        this.inputType = inputType;
+    public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
+        this(() -> input, inputType, false);
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
-        this(input, chunkingSettings, inputType, false);
+    public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
+        this(() -> input, inputType, stream);
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
-        this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
+    public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
+        this(inputSupplier, inputType, false);
     }
 
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
-        this(input, inputType, false);
-    }
-
-    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
+    private EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType, boolean stream) {
         super(stream);
-        Objects.requireNonNull(input);
-        this.listSupplier = () -> input;
+        this.inputListSupplier = Objects.requireNonNull(inputSupplier);
        this.inputType = inputType;
     }
 
-    public List<ChunkInferenceInput> getInputs() {
-        return this.listSupplier.get();
-    }
-
-    public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
-        return new EmbeddingsInput(input, null, inputType);
-    }
-
-    public List<String> getStringInputs() {
-        return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
+    /**
+     * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply
+     * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input
+     * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling
+     * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown.
+     *
+     * @return a list of String embedding inputs
+     */
+    public List<String> getInputs() {
+        assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice";
+        return inputListSupplier.get();
     }
 
     public InputType getInputType() {
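
To make the new contract concrete, here is a hedged usage sketch, assuming only the constructors and getInputs() visible in this diff; the lazyChunks supplier and the InputType value are arbitrary placeholders. The supplier is stored at construction time and only invoked when getInputs() is called, so no chunk Strings are copied until the request is executed.

import java.util.List;
import java.util.function.Supplier;

import org.elasticsearch.inference.InputType;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;

class EmbeddingsInputUsageSketch {
    static List<String> resolve(Supplier<List<String>> lazyChunks) {
        // Construction stores the supplier without invoking it; InputType.INGEST is an
        // arbitrary choice for illustration.
        EmbeddingsInput input = new EmbeddingsInput(lazyChunks, InputType.INGEST);

        // The supplier runs exactly once, here. A second getInputs() call would trip
        // the "EmbeddingsInput supplier invoked twice" assertion in dev/test builds.
        return input.getInputs();
    }
}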

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs();
+        var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs();
         var truncatedInput = truncate(docsInput, maxInputTokens);
         var request = requestCreator.apply(truncatedInput);
 

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 3 additions & 3 deletions
@@ -101,7 +101,7 @@ private static InferenceInputs createInput(
                 if (validationException.validationErrors().isEmpty() == false) {
                     throw validationException;
                 }
-                yield new EmbeddingsInput(input, null, inputType, stream);
+                yield new EmbeddingsInput(input, inputType, stream);
             }
             default -> throw new ElasticsearchStatusException(
                 Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
@@ -140,7 +140,7 @@ public void chunkedInfer(
         }
 
         // a non-null query is not supported and is dropped by all providers
-        doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener);
+        doChunkedInfer(model, input, taskSettings, inputType, timeout, listener);
     }
 
     protected abstract void doInfer(
@@ -164,7 +164,7 @@ protected abstract void doUnifiedCompletionInfer(
 
     protected abstract void doChunkedInfer(
         Model model,
-        EmbeddingsInput inputs,
+        List<ChunkInferenceInput> inputs,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
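
For orientation on the new parameter type, here is a small hypothetical helper (not part of the commit) showing how an implementation of doChunkedInfer() can recover the raw Strings from the List<ChunkInferenceInput> it now receives, using the same input() accessor seen in the EmbeddingRequestChunker diff above.

import java.util.List;

import org.elasticsearch.inference.ChunkInferenceInput;

final class ChunkInputTexts {
    private ChunkInputTexts() {}

    // Hypothetical helper: maps each ChunkInferenceInput to its raw text without
    // retaining a reference to the ChunkInferenceInput list itself.
    static List<String> rawTexts(List<ChunkInferenceInput> inputs) {
        return inputs.stream().map(ChunkInferenceInput::input).toList();
    }
}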
