Skip to content

Commit 8d2c4c3

Browse files
authored
[ML] Do not convert input Strings to ChunkInferenceInput unless necessary (#134945)
The SenderService.infer() method was converting the input variable from a List<String> into a List<ChunkInferenceInput>, but then when that list was passed into SenderService.createInput() it was immediately converted back into a List<String>. To avoid unnecessary work, allow the EmbeddingsInput constructor to convert the list if necessary.
1 parent 106b301 commit 8d2c4c3

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,22 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
3030
private final Supplier<List<ChunkInferenceInput>> listSupplier;
3131
private final InputType inputType;
3232

33-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
34-
this(input, inputType, false);
35-
}
36-
3733
public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
3834
super(false);
3935
this.listSupplier = Objects.requireNonNull(inputSupplier);
4036
this.inputType = inputType;
4137
}
4238

4339
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
44-
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false);
40+
this(input, chunkingSettings, inputType, false);
41+
}
42+
43+
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
44+
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
45+
}
46+
47+
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
48+
this(input, inputType, false);
4549
}
4650

4751
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,39 +75,37 @@ public void infer(
7575
) {
7676
timeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
7777
init();
78-
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
79-
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
78+
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
8079
doInfer(model, inferenceInput, taskSettings, timeout, listener);
8180
}
8281

8382
private static InferenceInputs createInput(
8483
SenderService service,
8584
Model model,
86-
List<ChunkInferenceInput> input,
85+
List<String> input,
8786
InputType inputType,
8887
@Nullable String query,
8988
@Nullable Boolean returnDocuments,
9089
@Nullable Integer topN,
9190
boolean stream
9291
) {
93-
List<String> textInput = ChunkInferenceInput.inputs(input);
9492
return switch (model.getTaskType()) {
95-
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream);
93+
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
9694
case RERANK -> {
9795
ValidationException validationException = new ValidationException();
9896
service.validateRerankParameters(returnDocuments, topN, validationException);
9997
if (validationException.validationErrors().isEmpty() == false) {
10098
throw validationException;
10199
}
102-
yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream);
100+
yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
103101
}
104102
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
105103
ValidationException validationException = new ValidationException();
106104
service.validateInputType(inputType, model, validationException);
107105
if (validationException.validationErrors().isEmpty() == false) {
108106
throw validationException;
109107
}
110-
yield new EmbeddingsInput(input, inputType, stream);
108+
yield new EmbeddingsInput(input, null, inputType, stream);
111109
}
112110
default -> throw new ElasticsearchStatusException(
113111
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),

0 commit comments

Comments
 (0)