Skip to content

Commit 04f38ca

Browse files
authored
[ML] Delay copying chunked input strings (elastic#125837) (elastic#126402)
The chunked text is only required when the actual inference request is made. Using a string supplier means the string creation can be done much closer to where the request is made, reducing the lifespan of the copied string. (cherry picked from commit c521264) # Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
1 parent cb83936 commit 04f38ca

File tree

11 files changed

+79
-62
lines changed

11 files changed

+79
-62
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424
import java.util.concurrent.atomic.AtomicInteger;
2525
import java.util.concurrent.atomic.AtomicReferenceArray;
26+
import java.util.function.Supplier;
2627
import java.util.stream.Collectors;
2728

2829
/**
@@ -40,14 +41,14 @@ public class EmbeddingRequestChunker<E extends EmbeddingResults.Embedding<E>> {
4041

4142
// Visible for testing
4243
record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List<String> inputs) {
43-
public String chunkText() {
44+
String chunkText() {
4445
return inputs.get(inputIndex).substring(chunk.start(), chunk.end());
4546
}
4647
}
4748

4849
public record BatchRequest(List<Request> requests) {
49-
public List<String> inputs() {
50-
return requests.stream().map(Request::chunkText).collect(Collectors.toList());
50+
public Supplier<List<String>> inputs() {
51+
return () -> requests.stream().map(Request::chunkText).collect(Collectors.toList());
5152
}
5253
}
5354

@@ -144,7 +145,7 @@ public List<BatchRequestAndListener> batchRequestsWithListeners(ActionListener<L
144145
*/
145146
private class DebatchingListener implements ActionListener<InferenceServiceResults> {
146147

147-
private final BatchRequest request;
148+
private BatchRequest request;
148149

149150
DebatchingListener(BatchRequest request) {
150151
this.request = request;
@@ -170,6 +171,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) {
170171
oldEmbedding -> oldEmbedding == null ? newEmbedding : oldEmbedding.merge(newEmbedding)
171172
);
172173
}
174+
request = null;
173175
if (resultCount.incrementAndGet() == batchRequests.size()) {
174176
sendFinalResponse();
175177
}
@@ -197,6 +199,7 @@ public void onFailure(Exception e) {
197199
for (Request request : request.requests) {
198200
resultsErrors.set(request.inputIndex(), e);
199201
}
202+
this.request = null;
200203
if (resultCount.incrementAndGet() == batchRequests.size()) {
201204
sendFinalResponse();
202205
}
@@ -208,6 +211,7 @@ private void sendFinalResponse() {
208211
for (int i = 0; i < resultEmbeddings.size(); i++) {
209212
if (resultsErrors.get(i) != null) {
210213
response.add(new ChunkedInferenceError(resultsErrors.get(i)));
214+
resultsErrors.set(i, null);
211215
} else {
212216
response.add(mergeResultsWithInputs(i));
213217
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public SingleInputSenderExecutableAction(
3333

3434
@Override
3535
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
36-
if (inferenceInputs.inputSize() > 1) {
36+
if (inferenceInputs.isSingleInput() == false) {
3737
listener.onFailure(
3838
new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST)
3939
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ public List<String> getInputs() {
3535
return this.input;
3636
}
3737

38-
public int inputSize() {
39-
return input.size();
38+
@Override
39+
public boolean isSingleInput() {
40+
return input.size() == 1;
4041
}
4142
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import java.util.List;
1414
import java.util.Objects;
15+
import java.util.function.Supplier;
1516

1617
public class EmbeddingsInput extends InferenceInputs {
1718

@@ -23,29 +24,38 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
2324
return (EmbeddingsInput) inferenceInputs;
2425
}
2526

26-
private final List<String> input;
27-
27+
private final Supplier<List<String>> listSupplier;
2828
private final InputType inputType;
2929

3030
public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
3131
this(input, inputType, false);
3232
}
3333

34+
public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
35+
super(false);
36+
this.listSupplier = Objects.requireNonNull(inputSupplier);
37+
this.inputType = inputType;
38+
}
39+
3440
public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
3541
super(stream);
36-
this.input = Objects.requireNonNull(input);
42+
Objects.requireNonNull(input);
43+
this.listSupplier = () -> input;
3744
this.inputType = inputType;
3845
}
3946

4047
public List<String> getInputs() {
41-
return this.input;
48+
return this.listSupplier.get();
4249
}
4350

4451
public InputType getInputType() {
4552
return this.inputType;
4653
}
4754

48-
public int inputSize() {
49-
return input.size();
55+
@Override
56+
public boolean isSingleInput() {
57+
// We can't measure the size of the input list without executing
58+
// the supplier.
59+
return false;
5060
}
5161
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,5 @@ public boolean stream() {
3434
return stream;
3535
}
3636

37-
public abstract int inputSize();
37+
public abstract boolean isSingleInput();
3838
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ public Integer getTopN() {
6161
return topN;
6262
}
6363

64-
public int inputSize() {
65-
return chunks.size();
64+
@Override
65+
public boolean isSingleInput() {
66+
return chunks.size() == 1;
6667
}
6768
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public UnifiedCompletionRequest getRequest() {
4949
return request;
5050
}
5151

52-
public int inputSize() {
53-
return request.messages().size();
52+
public boolean isSingleInput() {
53+
return request.messages().size() == 1;
5454
}
5555
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,7 @@ private void executeRequest(int batchIndex, boolean maybeDeploy, Runnable runAft
11131113
var inferenceRequest = buildInferenceRequest(
11141114
esModel.mlNodeDeploymentId(),
11151115
EmptyConfigUpdate.INSTANCE,
1116-
batch.batch().inputs(),
1116+
batch.batch().inputs().get(),
11171117
inputType,
11181118
timeout
11191119
);

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -49,53 +49,53 @@ public void testWhitespaceInput_SentenceChunker() {
4949
var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1))
5050
.batchRequestsWithListeners(testListener());
5151
assertThat(batches, hasSize(1));
52-
assertThat(batches.get(0).batch().inputs(), hasSize(1));
53-
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
52+
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
53+
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" "));
5454
}
5555

5656
public void testBlankInput_WordChunker() {
5757
var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener());
5858
assertThat(batches, hasSize(1));
59-
assertThat(batches.get(0).batch().inputs(), hasSize(1));
60-
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
59+
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
60+
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(""));
6161
}
6262

6363
public void testBlankInput_SentenceChunker() {
6464
var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1))
6565
.batchRequestsWithListeners(testListener());
6666
assertThat(batches, hasSize(1));
67-
assertThat(batches.get(0).batch().inputs(), hasSize(1));
68-
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
67+
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
68+
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(""));
6969
}
7070

7171
public void testInputThatDoesNotChunk_WordChunker() {
7272
var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener());
7373
assertThat(batches, hasSize(1));
74-
assertThat(batches.get(0).batch().inputs(), hasSize(1));
75-
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
74+
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
75+
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA"));
7676
}
7777

7878
public void testInputThatDoesNotChunk_SentenceChunker() {
7979
var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1))
8080
.batchRequestsWithListeners(testListener());
8181
assertThat(batches, hasSize(1));
82-
assertThat(batches.get(0).batch().inputs(), hasSize(1));
83-
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
82+
assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
83+
assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA"));
8484
}
8585

8686
public void testShortInputsAreSingleBatch() {
8787
String input = "one chunk";
8888
var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
8989
assertThat(batches, hasSize(1));
90-
assertThat(batches.get(0).batch().inputs(), contains(input));
90+
assertThat(batches.get(0).batch().inputs().get(), contains(input));
9191
}
9292

9393
public void testMultipleShortInputsAreSingleBatch() {
9494
List<String> inputs = List.of("1st small", "2nd small", "3rd small");
9595
var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener());
9696
assertThat(batches, hasSize(1));
9797
EmbeddingRequestChunker.BatchRequest batch = batches.get(0).batch();
98-
assertEquals(batch.inputs(), inputs);
98+
assertEquals(batch.inputs().get(), inputs);
9999
for (int i = 0; i < inputs.size(); i++) {
100100
var request = batch.requests().get(i);
101101
assertThat(request.chunkText(), equalTo(inputs.get(i)));
@@ -115,20 +115,20 @@ public void testManyInputsMakeManyBatches() {
115115

116116
var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener());
117117
assertThat(batches, hasSize(4));
118-
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
119-
assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));
120-
assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch));
121-
assertThat(batches.get(3).batch().inputs(), hasSize(1));
118+
assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
119+
assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
120+
assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
121+
assertThat(batches.get(3).batch().inputs().get(), hasSize(1));
122122

123-
assertEquals("input 0", batches.get(0).batch().inputs().get(0));
124-
assertEquals("input 9", batches.get(0).batch().inputs().get(9));
123+
assertEquals("input 0", batches.get(0).batch().inputs().get().get(0));
124+
assertEquals("input 9", batches.get(0).batch().inputs().get().get(9));
125125
assertThat(
126-
batches.get(1).batch().inputs(),
126+
batches.get(1).batch().inputs().get(),
127127
contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19")
128128
);
129-
assertEquals("input 20", batches.get(2).batch().inputs().get(0));
130-
assertEquals("input 29", batches.get(2).batch().inputs().get(9));
131-
assertThat(batches.get(3).batch().inputs(), contains("input 30"));
129+
assertEquals("input 20", batches.get(2).batch().inputs().get().get(0));
130+
assertEquals("input 29", batches.get(2).batch().inputs().get().get(9));
131+
assertThat(batches.get(3).batch().inputs().get(), contains("input 30"));
132132

133133
List<EmbeddingRequestChunker.Request> requests = batches.get(0).batch().requests();
134134
for (int i = 0; i < requests.size(); i++) {
@@ -151,20 +151,20 @@ public void testChunkingSettingsProvided() {
151151
var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings())
152152
.batchRequestsWithListeners(testListener());
153153
assertThat(batches, hasSize(4));
154-
assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
155-
assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));
156-
assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch));
157-
assertThat(batches.get(3).batch().inputs(), hasSize(1));
154+
assertThat(batches.get(0).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
155+
assertThat(batches.get(1).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
156+
assertThat(batches.get(2).batch().inputs().get(), hasSize(maxNumInputsPerBatch));
157+
assertThat(batches.get(3).batch().inputs().get(), hasSize(1));
158158

159-
assertEquals("input 0", batches.get(0).batch().inputs().get(0));
160-
assertEquals("input 9", batches.get(0).batch().inputs().get(9));
159+
assertEquals("input 0", batches.get(0).batch().inputs().get().get(0));
160+
assertEquals("input 9", batches.get(0).batch().inputs().get().get(9));
161161
assertThat(
162-
batches.get(1).batch().inputs(),
162+
batches.get(1).batch().inputs().get(),
163163
contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19")
164164
);
165-
assertEquals("input 20", batches.get(2).batch().inputs().get(0));
166-
assertEquals("input 29", batches.get(2).batch().inputs().get(9));
167-
assertThat(batches.get(3).batch().inputs(), contains("input 30"));
165+
assertEquals("input 20", batches.get(2).batch().inputs().get().get(0));
166+
assertEquals("input 29", batches.get(2).batch().inputs().get().get(9));
167+
assertThat(batches.get(3).batch().inputs().get(), contains("input 30"));
168168

169169
List<EmbeddingRequestChunker.Request> requests = batches.get(0).batch().requests();
170170
for (int i = 0; i < requests.size(); i++) {
@@ -195,7 +195,7 @@ public void testLongInputChunkedOverMultipleBatches() {
195195
assertThat(batches, hasSize(2));
196196

197197
var batch = batches.get(0).batch();
198-
assertThat(batch.inputs(), hasSize(batchSize));
198+
assertThat(batch.inputs().get(), hasSize(batchSize));
199199
assertThat(batch.requests(), hasSize(batchSize));
200200

201201
EmbeddingRequestChunker.Request request = batch.requests().get(0);
@@ -212,7 +212,7 @@ public void testLongInputChunkedOverMultipleBatches() {
212212
}
213213

214214
batch = batches.get(1).batch();
215-
assertThat(batch.inputs(), hasSize(4));
215+
assertThat(batch.inputs().get(), hasSize(4));
216216
assertThat(batch.requests(), hasSize(4));
217217

218218
for (int requestIndex = 0; requestIndex < 2; requestIndex++) {
@@ -254,9 +254,9 @@ public void testVeryLongInput_Sparse() {
254254
// there are 10002 inference requests, resulting in 2001 batches.
255255
assertThat(batches, hasSize(2001));
256256
for (int i = 0; i < 2000; i++) {
257-
assertThat(batches.get(i).batch().inputs(), hasSize(5));
257+
assertThat(batches.get(i).batch().inputs().get(), hasSize(5));
258258
}
259-
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
259+
assertThat(batches.get(2000).batch().inputs().get(), hasSize(2));
260260

261261
// Produce inference results for each request, with just the token
262262
// "word" and increasing weights.
@@ -339,9 +339,9 @@ public void testVeryLongInput_Float() {
339339
// there are 10002 inference requests, resulting in 2001 batches.
340340
assertThat(batches, hasSize(2001));
341341
for (int i = 0; i < 2000; i++) {
342-
assertThat(batches.get(i).batch().inputs(), hasSize(5));
342+
assertThat(batches.get(i).batch().inputs().get(), hasSize(5));
343343
}
344-
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
344+
assertThat(batches.get(2000).batch().inputs().get(), hasSize(2));
345345

346346
// Produce inference results for each request, with increasing weights.
347347
float weight = 0f;
@@ -423,9 +423,9 @@ public void testVeryLongInput_Byte() {
423423
// there are 10002 inference requests, resulting in 2001 batches.
424424
assertThat(batches, hasSize(2001));
425425
for (int i = 0; i < 2000; i++) {
426-
assertThat(batches.get(i).batch().inputs(), hasSize(5));
426+
assertThat(batches.get(i).batch().inputs().get(), hasSize(5));
427427
}
428-
assertThat(batches.get(2000).batch().inputs(), hasSize(2));
428+
assertThat(batches.get(2000).batch().inputs().get(), hasSize(2));
429429

430430
// Produce inference results for each request, with increasing weights.
431431
byte weight = 0;

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
1818
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
1919
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
20+
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
2021
import org.junit.Before;
2122

2223
import java.util.List;
@@ -53,7 +54,7 @@ public void testOneInputIsValid() {
5354
var testRan = new AtomicBoolean(false);
5455

5556
executableAction.execute(
56-
mock(EmbeddingsInput.class),
57+
new UnifiedChatInput(List.of("one"), "system", false),
5758
mock(TimeValue.class),
5859
ActionListener.wrap(success -> testRan.set(true), e -> fail(e, "Test failed."))
5960
);
@@ -65,7 +66,7 @@ public void testMoreThanOneInput() {
6566
var badInput = mock(EmbeddingsInput.class);
6667
var input = List.of("one", "two");
6768
when(badInput.getInputs()).thenReturn(input);
68-
when(badInput.inputSize()).thenReturn(input.size());
69+
when(badInput.isSingleInput()).thenReturn(false);
6970
var actualException = new AtomicReference<Exception>();
7071

7172
executableAction.execute(

0 commit comments

Comments
 (0)