diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 1e188d0f7bf5b..55cdb7207e25d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -30,10 +30,6 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) { private final Supplier> listSupplier; private final InputType inputType; - public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(input, inputType, false); - } - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { super(false); this.listSupplier = Objects.requireNonNull(inputSupplier); @@ -41,7 +37,15 @@ public EmbeddingsInput(Supplier> inputSupplier, @Nulla } public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { - this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false); + this(input, chunkingSettings, inputType, false); + } + + public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) { + this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream); + } + + public EmbeddingsInput(List input, @Nullable InputType inputType) { + this(input, inputType, false); } public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index ff8ae6fd5aac3..657834e6831ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -71,31 +71,29 @@ public void infer( ActionListener listener ) { init(); - var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList(); - var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream); + var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream); doInfer(model, inferenceInput, taskSettings, timeout, listener); } private static InferenceInputs createInput( SenderService service, Model model, - List input, + List input, InputType inputType, @Nullable String query, @Nullable Boolean returnDocuments, @Nullable Integer topN, boolean stream ) { - List textInput = ChunkInferenceInput.inputs(input); return switch (model.getTaskType()) { - case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); + case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); case RERANK -> { ValidationException validationException = new ValidationException(); service.validateRerankParameters(returnDocuments, topN, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream); + yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); } case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { ValidationException validationException = new ValidationException(); @@ -103,7 +101,7 @@ private static InferenceInputs createInput( if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new EmbeddingsInput(input, inputType, stream); + yield new EmbeddingsInput(input, null, inputType, stream); } default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),