Commit ce73a90

[ML] Use the same chunking configurations for models in the Elasticsearch service (#111336)
1 parent 4990276 commit ce73a90

7 files changed: +386 −162 lines changed

docs/changelog/111336.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 111336
+summary: Use the same chunking configurations for models in the Elasticsearch service
+area: Machine Learning
+type: enhancement
+issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 53 additions & 1 deletion
@@ -16,10 +16,13 @@
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -42,7 +45,8 @@ public class EmbeddingRequestChunker {
 
     public enum EmbeddingType {
         FLOAT,
-        BYTE;
+        BYTE,
+        SPARSE;
 
         public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.ElementType elementType) {
             return switch (elementType) {
@@ -67,6 +71,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
     private List<List<String>> chunkedInputs;
     private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
     private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
+    private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
     private AtomicArray<ErrorChunkedInferenceResults> errors;
     private ActionListener<List<ChunkedInferenceServiceResults>> finalListener;
 
@@ -117,6 +122,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
         switch (embeddingType) {
             case FLOAT -> floatResults = new ArrayList<>(inputs.size());
             case BYTE -> byteResults = new ArrayList<>(inputs.size());
+            case SPARSE -> sparseResults = new ArrayList<>(inputs.size());
         }
         errors = new AtomicArray<>(inputs.size());
 
@@ -127,6 +133,7 @@ private void splitIntoBatchedRequests(List<String> inputs) {
             switch (embeddingType) {
                 case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches));
                 case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches));
+                case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches));
             }
             chunkedInputs.add(chunks);
         }
@@ -217,6 +224,7 @@ public void onResponse(InferenceServiceResults inferenceServiceResults) {
             switch (embeddingType) {
                 case FLOAT -> handleFloatResults(inferenceServiceResults);
                 case BYTE -> handleByteResults(inferenceServiceResults);
+                case SPARSE -> handleSparseResults(inferenceServiceResults);
             }
         }
 
@@ -266,6 +274,29 @@ private void handleByteResults(InferenceServiceResults inferenceServiceResults)
         }
     }
 
+    private void handleSparseResults(InferenceServiceResults inferenceServiceResults) {
+        if (inferenceServiceResults instanceof SparseEmbeddingResults sparseEmbeddings) {
+            if (failIfNumRequestsDoNotMatch(sparseEmbeddings.embeddings().size())) {
+                return;
+            }
+
+            int start = 0;
+            for (var pos : positions) {
+                sparseResults.get(pos.inputIndex())
+                    .setOnce(pos.chunkIndex(), sparseEmbeddings.embeddings().subList(start, start + pos.embeddingCount()));
+                start += pos.embeddingCount();
+            }
+
+            if (resultCount.incrementAndGet() == totalNumberOfRequests) {
+                sendResponse();
+            }
+        } else {
+            onFailure(
+                unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), InferenceTextEmbeddingByteResults.NAME)
+            );
+        }
+    }
+
     private boolean failIfNumRequestsDoNotMatch(int numberOfResults) {
         int numberOfRequests = positions.stream().mapToInt(SubBatchPositionsAndCount::embeddingCount).sum();
         if (numberOfRequests != numberOfResults) {
@@ -319,6 +350,7 @@ private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) {
         return switch (embeddingType) {
             case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex));
             case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex));
+            case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex));
         };
     }
 
@@ -366,6 +398,26 @@ private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs(
         return new InferenceChunkedTextEmbeddingByteResults(embeddingChunks, false);
     }
 
+    private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs(
+        List<String> chunks,
+        AtomicArray<List<SparseEmbeddingResults.Embedding>> debatchedResults
+    ) {
+        var all = new ArrayList<SparseEmbeddingResults.Embedding>();
+        for (int i = 0; i < debatchedResults.length(); i++) {
+            var subBatch = debatchedResults.get(i);
+            all.addAll(subBatch);
+        }
+
+        assert chunks.size() == all.size();
+
+        var embeddingChunks = new ArrayList<MlChunkedTextExpansionResults.ChunkedResult>();
+        for (int i = 0; i < chunks.size(); i++) {
+            embeddingChunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunks.get(i), all.get(i).tokens()));
+        }
+
+        return new InferenceChunkedSparseEmbeddingResults(embeddingChunks);
+    }
+
     public record BatchRequest(List<SubBatch> subBatches) {
         public int size() {
             return subBatches.stream().mapToInt(SubBatch::size).sum();
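
For orientation, the sketch below shows roughly how a caller drives the new SPARSE path. This mirrors the usage ElasticsearchInternalService adopts later in this commit; sendToModel, the input list, and the import paths not shown in the hunks above are assumptions, not part of the commit.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;

import java.util.List;

// A minimal sketch, assuming the 3-arg constructor used later in this commit.
void chunkAndInferSparse(List<String> inputs, ActionListener<List<ChunkedInferenceServiceResults>> finalListener) {
    var batches = new EmbeddingRequestChunker(
        inputs,
        10,                                          // max inputs per request to the model
        EmbeddingRequestChunker.EmbeddingType.SPARSE // selects the new sparse accumulators
    ).batchRequestsWithListeners(finalListener);

    for (var batch : batches) {
        // sendToModel is a placeholder for the actual inference call. Each batch carries
        // its own listener; handleSparseResults() records the sub-batch embeddings and,
        // once every request has completed, mergeSparseResultsWithInputs() pairs each
        // chunk with its tokens and completes finalListener.
        sendToModel(batch.batch().inputs(), batch.listener());
    }
}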

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 2 additions & 3 deletions
@@ -248,15 +248,14 @@ public static InferModelAction.Request buildInferenceRequest(
         InferenceConfigUpdate update,
         List<String> inputs,
         InputType inputType,
-        TimeValue timeout,
-        boolean chunk
+        TimeValue timeout
     ) {
         var request = InferModelAction.Request.forTextInput(id, update, inputs, true, timeout);
         request.setPrefixType(
             InputType.SEARCH == inputType ? TrainedModelPrefixStrings.PrefixType.SEARCH : TrainedModelPrefixStrings.PrefixType.INGEST
         );
         request.setHighPriority(InputType.SEARCH == inputType);
-        request.setChunked(chunk);
+        request.setChunked(false);
         return request;
     }
 
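The dropped boolean is the heart of this change: chunking now happens in the inference service layer via EmbeddingRequestChunker before a request is built, so requests sent to the ML node are always unchunked. A sketch of the new call shape (the endpoint id and inputs here are placeholders):

// Call shape after this change; the trailing chunk flag is gone and
// setChunked(false) is applied unconditionally inside buildInferenceRequest.
var request = buildInferenceRequest(
    "my-elser-endpoint",                    // inference entity id (placeholder)
    TextExpansionConfigUpdate.EMPTY_UPDATE, // per-task config update, as in inferSparseEmbedding
    List.of("first passage", "second passage"),
    InputType.INGEST,
    TimeValue.timeValueSeconds(10)
);
client.execute(InferModelAction.INSTANCE, request, listener); // client/listener as in the service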

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,11 @@ public abstract ActionListener<CreateTrainedModelAssignmentAction.Response> getC
         ActionListener<Boolean> listener
     );
 
+    @Override
+    public ElasticsearchInternalServiceSettings getServiceSettings() {
+        return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
+    }
+
     @Override
     public String toString() {
         return Strings.toString(this.getConfigurations());
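The new covariant override narrows the declared return type so service code can read the typed settings without casting. A small illustration (esModel here is a placeholder for any ElasticsearchInternalModel):

// Previously call sites needed: (ElasticsearchInternalServiceSettings) esModel.getServiceSettings()
ElasticsearchInternalServiceSettings settings = esModel.getServiceSettings(); // no cast needed now
var elementType = settings.elementType(); // e.g. consulted below when choosing FLOAT vs BYTE embeddings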

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 89 additions & 48 deletions
@@ -28,21 +28,19 @@
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
+import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
-import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
@@ -74,6 +72,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
     );
 
+    public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
     public static final String DEFAULT_ELSER_ID = ".elser-2";
 
     private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
@@ -501,8 +500,7 @@ public void inferTextEmbedding(
             TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
             inputs,
             inputType,
-            timeout,
-            false
+            timeout
         );
 
         ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
@@ -528,8 +526,7 @@ public void inferSparseEmbedding(
             TextExpansionConfigUpdate.EMPTY_UPDATE,
             inputs,
             inputType,
-            timeout,
-            false
+            timeout
         );
 
         ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
@@ -557,8 +554,7 @@ public void inferRerank(
             new TextSimilarityConfigUpdate(query),
             inputs,
             inputType,
-            timeout,
-            false
+            timeout
         );
 
         var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings();
@@ -610,52 +606,80 @@ public void chunkedInfer(
 
         if (model instanceof ElasticsearchInternalModel esModel) {
 
-            var configUpdate = chunkingOptions != null
-                ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span())
-                : new TokenizationConfigUpdate(null, null);
-
-            var request = buildInferenceRequest(
-                model.getConfigurations().getInferenceEntityId(),
-                configUpdate,
+            var batchedRequests = new EmbeddingRequestChunker(
                 input,
-                inputType,
-                timeout,
-                true
-            );
+                EMBEDDING_MAX_BATCH_SIZE,
+                embeddingTypeFromTaskTypeAndSettings(model.getTaskType(), esModel.internalServiceSettings)
+            ).batchRequestsWithListeners(listener);
+
+            for (var batch : batchedRequests) {
+                var inferenceRequest = buildInferenceRequest(
+                    model.getConfigurations().getInferenceEntityId(),
+                    EmptyConfigUpdate.INSTANCE,
+                    batch.batch().inputs(),
+                    inputType,
+                    timeout
+                );
 
-            ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
-                (l, inferenceResult) -> l.onResponse(translateToChunkedResults(inferenceResult.getInferenceResults()))
-            );
+                ActionListener<InferModelAction.Response> mlResultsListener = batch.listener()
+                    .delegateFailureAndWrap(
+                        (l, inferenceResult) -> translateToChunkedResult(model.getTaskType(), inferenceResult.getInferenceResults(), l)
+                    );
 
-            var maybeDeployListener = mlResultsListener.delegateResponse(
-                (l, exception) -> maybeStartDeployment(esModel, exception, request, mlResultsListener)
-            );
+                var maybeDeployListener = mlResultsListener.delegateResponse(
+                    (l, exception) -> maybeStartDeployment(esModel, exception, inferenceRequest, mlResultsListener)
+                );
 
-            client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
+                client.execute(InferModelAction.INSTANCE, inferenceRequest, maybeDeployListener);
+            }
         } else {
             listener.onFailure(notElasticsearchModelException(model));
         }
     }
 
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(List<InferenceResults> inferenceResults) {
-        var translated = new ArrayList<ChunkedInferenceServiceResults>();
-
-        for (var inferenceResult : inferenceResults) {
-            translated.add(translateToChunkedResult(inferenceResult));
-        }
-
-        return translated;
-    }
+    private static void translateToChunkedResult(
+        TaskType taskType,
+        List<InferenceResults> inferenceResults,
+        ActionListener<InferenceServiceResults> chunkPartListener
+    ) {
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            var translated = new ArrayList<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>();
 
-    private static ChunkedInferenceServiceResults translateToChunkedResult(InferenceResults inferenceResult) {
-        if (inferenceResult instanceof MlChunkedTextEmbeddingFloatResults mlChunkedResult) {
-            return InferenceChunkedTextEmbeddingFloatResults.ofMlResults(mlChunkedResult);
-        } else if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult) {
-            return InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult);
-        } else if (inferenceResult instanceof ErrorInferenceResults error) {
-            return new ErrorChunkedInferenceResults(error.getException());
-        } else {
-            throw createInvalidChunkedResultException(MlChunkedTextEmbeddingFloatResults.NAME, inferenceResult.getWriteableName());
+            for (var inferenceResult : inferenceResults) {
+                if (inferenceResult instanceof MlTextEmbeddingResults mlTextEmbeddingResult) {
+                    translated.add(
+                        new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(mlTextEmbeddingResult.getInferenceAsFloat())
+                    );
+                } else if (inferenceResult instanceof ErrorInferenceResults error) {
+                    chunkPartListener.onFailure(error.getException());
+                    return;
+                } else {
+                    chunkPartListener.onFailure(
+                        createInvalidChunkedResultException(MlTextEmbeddingResults.NAME, inferenceResult.getWriteableName())
+                    );
+                    return;
+                }
+            }
+            chunkPartListener.onResponse(new InferenceTextEmbeddingFloatResults(translated));
+        } else { // sparse
+            var translated = new ArrayList<SparseEmbeddingResults.Embedding>();
+
+            for (var inferenceResult : inferenceResults) {
+                if (inferenceResult instanceof TextExpansionResults textExpansionResult) {
+                    translated.add(
+                        new SparseEmbeddingResults.Embedding(textExpansionResult.getWeightedTokens(), textExpansionResult.isTruncated())
+                    );
+                } else if (inferenceResult instanceof ErrorInferenceResults error) {
+                    chunkPartListener.onFailure(error.getException());
+                    return;
+                } else {
+                    chunkPartListener.onFailure(
+                        createInvalidChunkedResultException(TextExpansionResults.NAME, inferenceResult.getWriteableName())
+                    );
+                    return;
+                }
+            }
+            chunkPartListener.onResponse(new SparseEmbeddingResults(translated));
         }
     }
 
@@ -738,4 +762,21 @@ public List<UnparsedModel> defaultConfigs() {
     protected boolean isDefaultId(String inferenceId) {
         return DEFAULT_ELSER_ID.equals(inferenceId);
     }
+
+    static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
+        TaskType taskType,
+        ElasticsearchInternalServiceSettings serviceSettings
+    ) {
+        return switch (taskType) {
+            case SPARSE_EMBEDDING -> EmbeddingRequestChunker.EmbeddingType.SPARSE;
+            case TEXT_EMBEDDING -> serviceSettings.elementType() == null
+                ? EmbeddingRequestChunker.EmbeddingType.FLOAT
+                : EmbeddingRequestChunker.EmbeddingType.fromDenseVectorElementType(serviceSettings.elementType());
+            default -> throw new ElasticsearchStatusException(
+                "Chunking is not supported for task type [{}]",
+                RestStatus.BAD_REQUEST,
+                taskType
+            );
+        };
+    }
 }
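
The new mapping helper's behavior, summarized as an illustrative sketch (not a committed test; the settings instances below are placeholders):

// SPARSE_EMBEDDING always chunks into sparse (token-weight) embeddings.
assert embeddingTypeFromTaskTypeAndSettings(TaskType.SPARSE_EMBEDDING, anySettings)
    == EmbeddingRequestChunker.EmbeddingType.SPARSE;

// TEXT_EMBEDDING falls back to FLOAT when no element type is configured...
assert embeddingTypeFromTaskTypeAndSettings(TaskType.TEXT_EMBEDDING, settingsWithoutElementType)
    == EmbeddingRequestChunker.EmbeddingType.FLOAT;

// ...otherwise it defers to EmbeddingType.fromDenseVectorElementType(elementType).
// Any other task type (e.g. RERANK) throws a 400 ElasticsearchStatusException,
// "Chunking is not supported for task type [...]".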
