Commit 2370870

Batch the chunks

1 parent 57532e7 commit 2370870
3 files changed: +192 -33 lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 62 additions & 19 deletions
@@ -62,6 +62,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;

@@ -668,25 +669,13 @@ public void chunkedInfer(
             ).batchRequestsWithListeners(listener);
         }

-        for (var batch : batchedRequests) {
-            var inferenceRequest = buildInferenceRequest(
-                esModel.mlNodeDeploymentId(),
-                EmptyConfigUpdate.INSTANCE,
-                batch.batch().inputs(),
-                inputType,
-                timeout
-            );
-
-            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
-                .delegateFailureAndWrap(
-                    (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
-                );
-
-            var maybeDeployListener = mlResultsListener.delegateResponse(
-                (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
-            );
-
-            client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+        if (batchedRequests.isEmpty()) {
+            listener.onResponse(List.of());
+        } else {
+            // Avoid filling the inference queue by executing the batches in series.
+            // Each batch contains up to EMBEDDING_MAX_BATCH_SIZE inference requests.
+            var sequentialRunner = new BatchIterator(esModel, inputType, timeout, batchedRequests);
+            sequentialRunner.doNextRequest();
         }
     } else {
         listener.onFailure(notElasticsearchModelException(model));
@@ -1004,4 +993,58 @@ static TaskType inferenceConfigToTaskType(InferenceConfig config) {
             return null;
         }
     }
+
+    // Iterates over the batches, sending one request at a time to avoid
+    // filling the ml node inference queue.
+    class BatchIterator {
+        private final AtomicInteger index = new AtomicInteger();
+        private final ElasticsearchInternalModel esModel;
+        private final List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners;
+        private final InputType inputType;
+        private final TimeValue timeout;
+
+        BatchIterator(
+            ElasticsearchInternalModel esModel,
+            InputType inputType,
+            TimeValue timeout,
+            List<EmbeddingRequestChunker.BatchRequestAndListener> requestAndListeners
+        ) {
+            this.esModel = esModel;
+            this.requestAndListeners = requestAndListeners;
+            this.inputType = inputType;
+            this.timeout = timeout;
+        }
+
+        void doNextRequest() {
+            inferenceExecutor.execute(() -> inferOnBatch(requestAndListeners.get(index.get())));
+        }
+
+        private void inferOnBatch(EmbeddingRequestChunker.BatchRequestAndListener batch) {
+            var inferenceRequest = buildInferenceRequest(
+                esModel.mlNodeDeploymentId(),
+                EmptyConfigUpdate.INSTANCE,
+                batch.batch().inputs(),
+                inputType,
+                timeout
+            );
+
+            ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
+                .delegateFailureAndWrap(
+                    (l, inferenceResult) -> translateToChunkedResult(esModel.getTaskType(), inferenceResult.getInferenceResults(), l)
+                );
+
+            // Schedule the next request once the results have been processed.
+            var scheduleNextListener = ActionListener.runAfter(mlResultsListener, () -> {
+                if (index.incrementAndGet() < requestAndListeners.size()) {
+                    doNextRequest();
+                }
+            });
+
+            var maybeDeployListener = scheduleNextListener.delegateResponse(
+                (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, scheduleNextListener)
+            );
+
+            client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+        }
+    }
 }
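
The heart of the change is a serial dispatch pattern: each batch's completion callback schedules the next batch, so at most one batch sits in the ml node's inference queue at a time, and ActionListener.runAfter fires on both success and failure, so a failed batch does not stall the remaining ones. Below is a minimal, self-contained sketch of the same pattern; the names (SerialBatchRunner, sendAsync) are illustrative stand-ins, not part of the Elasticsearch codebase.

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

// Illustrative sketch of the serial dispatch pattern BatchIterator uses:
// the completion callback of one batch triggers the next, so only one
// request is outstanding at a time.
class SerialBatchRunner<T> {
    private final AtomicInteger index = new AtomicInteger();
    private final List<T> batches;
    private final BiConsumer<T, Runnable> sendAsync; // (batch, onComplete)

    SerialBatchRunner(List<T> batches, BiConsumer<T, Runnable> sendAsync) {
        this.batches = batches;
        this.sendAsync = sendAsync;
    }

    void start() {
        if (batches.isEmpty() == false) {
            doNext();
        }
    }

    private void doNext() {
        sendAsync.accept(batches.get(index.get()), () -> {
            // advance only after the current batch's results are handled
            if (index.incrementAndGet() < batches.size()) {
                doNext();
            }
        });
    }
}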

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java

Lines changed: 13 additions & 0 deletions
@@ -24,12 +24,25 @@
 import java.util.concurrent.atomic.AtomicReference;

 import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.startsWith;

 public class EmbeddingRequestChunkerTests extends ESTestCase {

+    public void testEmptyInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, empty());
+    }
+
+    public void testBlankInput() {
+        var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
+        var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+    }
+
     public void testShortInputsAreSingleBatch() {
         String input = "one chunk";
         var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
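
These tests pin down the contract the service change above relies on: an empty input list yields no batches, so no per-batch listener will ever fire and the caller must complete the top-level listener itself, while a blank string is still one input and yields one batch. A toy sketch of that guard, with a plain-Java Consumer standing in for the ActionListener:

import java.util.List;
import java.util.function.Consumer;

// Toy illustration of the empty-batch guard; Consumer is an
// illustrative stand-in for the ActionListener in the real code.
public class EmptyBatchGuardDemo {
    public static void main(String[] args) {
        List<String> batches = List.of(); // what the chunker returns for no inputs
        Consumer<List<String>> listener = results -> System.out.println("responded: " + results);

        if (batches.isEmpty()) {
            // no batch callbacks will ever run, so answer immediately
            listener.accept(List.of());
        } else {
            batches.forEach(batch -> System.out.println("infer: " + batch));
        }
    }
}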

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 117 additions & 14 deletions
@@ -12,6 +12,7 @@
 import org.apache.logging.log4j.Level;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.LatchedActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -60,6 +61,7 @@
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.junit.After;
 import org.junit.Before;
@@ -73,6 +75,7 @@
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
@@ -936,17 +939,17 @@ public void testParsePersistedConfig() {
         }
     }

-    public void testChunkInfer_E5WithNullChunkingSettings() {
+    public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_e5(null);
     }

-    public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_E5ChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
     }

     @SuppressWarnings("unchecked")
-    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -994,6 +997,9 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1002,23 +1008,24 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }

-    public void testChunkInfer_SparseWithNullChunkingSettings() {
+    public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_Sparse(null);
     }

-    public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_SparseWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
     }

     @SuppressWarnings("unchecked")
-    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1042,6 +1049,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
         var service = createService(client);

         var gotResults = new AtomicBoolean();
+
         var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
             assertThat(chunkedResponse, hasSize(2));
             assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
@@ -1061,6 +1069,9 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1069,23 +1080,24 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }

-    public void testChunkInfer_ElserWithNullChunkingSettings() {
+    public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
         testChunkInfer_Elser(null);
     }

-    public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() {
+    public void testChunkInfer_ElserWithChunkingSettingsSetAndFeatureFlagEnabled() throws InterruptedException {
         assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
         testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
     }

     @SuppressWarnings("unchecked")
-    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
+    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
         mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
@@ -1129,6 +1141,9 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1137,9 +1152,10 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }
@@ -1200,7 +1216,7 @@ public void testChunkInferSetsTokenization() {
     }

     @SuppressWarnings("unchecked")
-    public void testChunkInfer_FailsBatch() {
+    public void testChunkInfer_FailsBatch() throws InterruptedException {
         var mlTrainedModelResults = new ArrayList<InferenceResults>();
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
         mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
@@ -1236,6 +1252,9 @@ public void testChunkInfer_FailsBatch() {
             gotResults.set(true);
         }, ESTestCase::fail);

+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
         service.chunkedInfer(
             model,
             null,
@@ -1244,9 +1263,93 @@ public void testChunkInfer_FailsBatch() {
             InputType.SEARCH,
             new ChunkingOptions(null, null),
             InferenceAction.Request.DEFAULT_TIMEOUT,
-            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
+            latchedListener
+        );
+
+        latch.await();
+        assertTrue("Listener not called", gotResults.get());
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testChunkingLargeDocument() throws InterruptedException {
+        assumeTrue("Only if 'inference_chunking_settings' feature flag is enabled", ChunkingSettingsFeatureFlag.isEnabled());
+
+        int wordsPerChunk = 10;
+        int numBatches = randomIntBetween(3, 6);
+        int numChunks = randomIntBetween(
+            ((numBatches - 1) * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE) + 1,
+            numBatches * ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE
+        );
+
+        // build a doc with enough words to make numChunks of chunks
+        int numWords = numChunks * wordsPerChunk;
+        var docBuilder = new StringBuilder();
+        for (int i = 0; i < numWords; i++) {
+            docBuilder.append("word ");
+        }
+
+        // how many response objects to return in each batch
+        int[] numResponsesPerBatch = new int[numBatches];
+        for (int i = 0; i < numBatches - 1; i++) {
+            numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+        }
+        numResponsesPerBatch[numBatches - 1] = numChunks % ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
+
+        var batchIndex = new AtomicInteger();
+        Client client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadPool);
+
+        // mock the inference response
+        doAnswer(invocationOnMock -> {
+            var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
+
+            var mlTrainedModelResults = new ArrayList<InferenceResults>();
+            for (int i = 0; i < numResponsesPerBatch[batchIndex.get()]; i++) {
+                mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
+            }
+            batchIndex.incrementAndGet();
+            var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
+            listener.onResponse(response);
+            return null;
+        }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));
+
+        var service = createService(client);
+
+        var gotResults = new AtomicBoolean();
+        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+            assertThat(chunkedResponse, hasSize(1));
+            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+            var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
+            assertThat(sparseResults.chunks(), hasSize(numChunks));
+
+            gotResults.set(true);
+        }, ESTestCase::fail);
+
+        // Create model using the word boundary chunker.
+        var model = new MultilingualE5SmallModel(
+            "foo",
+            TaskType.TEXT_EMBEDDING,
+            "e5",
+            new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
+            new WordBoundaryChunkingSettings(wordsPerChunk, 0)
+        );
+
+        var latch = new CountDownLatch(1);
+        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
+
+        // For the given input we know how many requests will be made
+        service.chunkedInfer(
+            model,
+            null,
+            List.of(docBuilder.toString()),
+            Map.of(),
+            InputType.SEARCH,
+            new ChunkingOptions(null, null),
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            latchedListener
         );

+        latch.await();
         assertTrue("Listener not called", gotResults.get());
     }

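The test changes all follow one pattern: instead of tearing down the thread pool from inside the listener via ActionListener.runAfter(resultsListener, () -> terminate(threadPool)), each listener is wrapped in a LatchedActionListener and the test thread blocks on latch.await() until the async callback has run, which is what the now-serial (multi-callback) execution requires. A plain-Java sketch of the same synchronization, without the Elasticsearch test utilities:

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

// Plain-Java stand-in for the LatchedActionListener pattern: the latch is
// counted down after the wrapped callback runs, and the test thread waits.
public class LatchPatternDemo {
    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        AtomicBoolean gotResults = new AtomicBoolean();
        CountDownLatch latch = new CountDownLatch(1);

        executor.execute(() -> {
            try {
                gotResults.set(true); // the assertions in the real listener
            } finally {
                latch.countDown();    // what LatchedActionListener adds
            }
        });

        latch.await(); // block until the async callback has completed
        if (gotResults.get() == false) {
            throw new AssertionError("Listener not called");
        }
        executor.shutdown();
    }
}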