Skip to content

Commit 95c31ad

Browse files
authored
[ML] Do not convert input Strings to ChunkInferenceInput unless necessary (#134945) (#135291)
The SenderService.infer() method converted its input from a List<String> into a List<ChunkInferenceInput>, but when that list was passed to SenderService.createInput() it was immediately converted back into a List<String>. To avoid this unnecessary work, the List<String> is now passed through unchanged and the EmbeddingsInput constructor converts the list only when necessary.

(cherry picked from commit 8d2c4c3)
1 parent 4fc9fa2 commit 95c31ad

File tree

2 files changed: +14 -12 lines changed

2 files changed

+14
-12
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,18 +30,22 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
3030
private final Supplier<List<ChunkInferenceInput>> listSupplier;
3131
private final InputType inputType;
3232

33-
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
34-
this(input, inputType, false);
35-
}
36-
3733
public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
3834
super(false);
3935
this.listSupplier = Objects.requireNonNull(inputSupplier);
4036
this.inputType = inputType;
4137
}
4238

4339
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
44-
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false);
40+
this(input, chunkingSettings, inputType, false);
41+
}
42+
43+
public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) {
44+
this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream);
45+
}
46+
47+
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
48+
this(input, inputType, false);
4549
}
4650

4751
public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -71,39 +71,37 @@ public void infer(
7171
ActionListener<InferenceServiceResults> listener
7272
) {
7373
init();
74-
var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
75-
var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
74+
var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
7675
doInfer(model, inferenceInput, taskSettings, timeout, listener);
7776
}
7877

7978
private static InferenceInputs createInput(
8079
SenderService service,
8180
Model model,
82-
List<ChunkInferenceInput> input,
81+
List<String> input,
8382
InputType inputType,
8483
@Nullable String query,
8584
@Nullable Boolean returnDocuments,
8685
@Nullable Integer topN,
8786
boolean stream
8887
) {
89-
List<String> textInput = ChunkInferenceInput.inputs(input);
9088
return switch (model.getTaskType()) {
91-
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream);
89+
case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
9290
case RERANK -> {
9391
ValidationException validationException = new ValidationException();
9492
service.validateRerankParameters(returnDocuments, topN, validationException);
9593
if (validationException.validationErrors().isEmpty() == false) {
9694
throw validationException;
9795
}
98-
yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream);
96+
yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
9997
}
10098
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
10199
ValidationException validationException = new ValidationException();
102100
service.validateInputType(inputType, model, validationException);
103101
if (validationException.validationErrors().isEmpty() == false) {
104102
throw validationException;
105103
}
106-
yield new EmbeddingsInput(input, inputType, stream);
104+
yield new EmbeddingsInput(input, null, inputType, stream);
107105
}
108106
default -> throw new ElasticsearchStatusException(
109107
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),

0 commit comments

Comments
 (0)