diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2df2f1e62f89a..da071442d6c1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -37,16 +37,16 @@ * a single large input that has been chunked may spread over * multiple batches. * - * The final aspect it to gather the responses from the batch + * The final aspect is to gather the responses from the batch * processing and map the results back to the original element * in the input list. */ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) { public String chunkText() { - return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end()); + return input.substring(chunk.start(), chunk.end()); } } @@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); + String inputString = inputs.get(inputIndex).input(); + List chunks = chunker.chunk(inputString, chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); @@ -129,7 +130,7 @@ public EmbeddingRequestChunker( } else { resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end()); } - allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs)); + allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 55cdb7207e25d..f9fd3a2011ee0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,63 +8,47 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ChunkInferenceInput; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import java.util.stream.Collectors; public class EmbeddingsInput extends InferenceInputs { - - public static EmbeddingsInput of(InferenceInputs inferenceInputs) { - if (inferenceInputs instanceof EmbeddingsInput == false) { - throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class); - } - - return (EmbeddingsInput) inferenceInputs; - } - - private final Supplier> listSupplier; + private final Supplier> inputListSupplier; private final InputType inputType; + private final AtomicBoolean supplierInvoked = new AtomicBoolean(); - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { - super(false); - this.listSupplier = Objects.requireNonNull(inputSupplier); - this.inputType = inputType; + public EmbeddingsInput(List input, @Nullable InputType inputType) { + this(() -> input, inputType, false); } - public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { - this(input, chunkingSettings, inputType, false); + public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + this(() -> input, inputType, stream); } - public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) { - this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream); + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { + this(inputSupplier, inputType, false); } - public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(input, inputType, false); - } - - public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream) { super(stream); - Objects.requireNonNull(input); - this.listSupplier = () -> input; + this.inputListSupplier = Objects.requireNonNull(inputSupplier); this.inputType = inputType; } - public List getInputs() { - return this.listSupplier.get(); - } - - public static EmbeddingsInput fromStrings(List input, @Nullable InputType inputType) { - return new EmbeddingsInput(input, null, inputType); - } - - public List getStringInputs() { - return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList()); + /** + * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply + * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input + * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling + * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown. + * + * @return a list of String embedding inputs + */ + public List getInputs() { + assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice"; + return inputListSupplier.get(); } public InputType getInputType() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index c39387d647f77..4a485f87858aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 657834e6831ff..0147c62823f0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -101,7 +101,7 @@ private static InferenceInputs createInput( if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new EmbeddingsInput(input, null, inputType, stream); + yield new EmbeddingsInput(input, inputType, stream); } default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), @@ -140,7 +140,7 @@ public void chunkedInfer( } // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener); + doChunkedInfer(model, input, taskSettings, inputType, timeout, listener); } protected abstract void doInfer( @@ -164,7 +164,7 @@ protected abstract void doUnifiedCompletionInfer( protected abstract void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java index f11cd41b25aa0..8a77f65592226 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -71,8 +71,9 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 7897317736c72..ff6482a795f77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -321,7 +322,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -336,14 +337,14 @@ protected void doChunkedInfer( var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, alibabaCloudSearchModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java index acce3b9a1d6ea..20ff8ce58b550 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java @@ -71,8 +71,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java index 06910611e0a96..387d8b65f40d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java @@ -56,8 +56,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 591607953ea1a..04122127b72d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -136,7 +137,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -147,14 +148,14 @@ protected void doChunkedInfer( var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, maxBatchSize, baseAmazonBedrockModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 791518ccc9168..8be92ce68abb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -26,7 +27,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -222,7 +222,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java index abb9b26a80b0c..afb268ab499a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java @@ -50,8 +50,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 04883f23b947f..34ebac37a2a55 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -123,7 +124,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -133,14 +134,14 @@ protected void doChunkedInfer( var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, baseAzureAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java index e98bf731210d7..db38b3fb0def3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java @@ -63,8 +63,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9ff97c1ba725..66fb753d76b7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -272,7 +273,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -286,14 +287,14 @@ protected void doChunkedInfer( var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, azureOpenAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c2f1221763165..6c87ee5b86d68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -274,7 +275,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -289,14 +290,14 @@ protected void doChunkedInfer( var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, cohereModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 777ddc348bda6..121f0e1e80a96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -81,8 +81,8 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map : overriddenModel.getTaskSettings().getInputType(); return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { - case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); - case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); }; }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index 6e2400c33d6a2..1c08fdb9f4285 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -75,7 +75,7 @@ public void execute( requestParameters = CompletionParameters.of(chatInputs); } else if (inferenceInputs instanceof EmbeddingsInput) { requestParameters = EmbeddingParameters.of( - EmbeddingsInput.of(inferenceInputs), + inferenceInputs.castTo(EmbeddingsInput.class), model.getServiceSettings().getInputTypeTranslator() ); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4a4166cf65ed3..16acfaa1af430 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -136,7 +137,7 @@ private static RequestParameters createParameters(CustomModel model) { case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input"))); case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input"))); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of( - new EmbeddingsInput(List.of("test input"), null, null), + new EmbeddingsInput(List.of("test input"), null), model.getServiceSettings().getInputTypeTranslator() ); default -> throw new IllegalStateException( @@ -280,7 +281,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -298,14 +299,14 @@ protected void doChunkedInfer( var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, customModel.getServiceSettings().getBatchSize(), customModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java index 71eb0ed8e098d..ef91045eb1dab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java @@ -28,7 +28,7 @@ public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeT private final InputTypeTranslator translator; private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) { - super(embeddingsInput.getStringInputs()); + super(embeddingsInput.getInputs()); this.inputType = embeddingsInput.getInputType(); this.translator = translator; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..22638ed7463a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -25,7 +26,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -106,7 +106,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 4f8d0d01861cd..4c14b507ecb2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; @@ -351,7 +352,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -361,14 +362,14 @@ protected void doChunkedInfer( var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, denseTextEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } return; @@ -378,14 +379,14 @@ protected void doChunkedInfer( var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE, model.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } return; @@ -618,7 +619,7 @@ public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); + var inputsAsList = inputs.castTo(EmbeddingsInput.class).getInputs(); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index ea82eb228dbc8..f34d538c413dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -69,8 +69,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 8b987cd53bc81..749a3277929f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -97,7 +97,7 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m DENSE_TEXT_EMBEDDINGS_HANDLER, (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( model, - embeddingsInput.getStringInputs(), + embeddingsInput.getInputs(), traceContext, extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), embeddingsInput.getInputType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java index 13e54b9e3e17b..e65bfda857c1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java @@ -56,8 +56,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..8665e744b76e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -340,7 +341,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -349,19 +350,13 @@ protected void doChunkedInfer( GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, googleAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer( - model, - EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), - taskSettings, - timeout, - request.listener() - ); + doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java index 2dc60ef114459..90d7a0b1b0a11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java @@ -64,8 +64,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..6a725a0395f7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -264,7 +265,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -274,14 +275,14 @@ protected void doChunkedInfer( var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, googleVertexAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index 7bb140e91ec5d..833f9bc6b347b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -62,8 +62,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); - var truncatedInput = truncate(docsInput, model.getTokenLimit()); + List inputs = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var truncatedInput = truncate(inputs, model.getTokenLimit()); var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..4c110505706c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -140,7 +141,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -155,14 +156,14 @@ protected void doChunkedInfer( var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, huggingFaceModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..934ace507219f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -94,7 +94,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -105,22 +105,31 @@ protected void doChunkedInfer( ); // TODO chunking sparse embeddings not implemented - doInfer(model, inputs, taskSettings, timeout, inferListener); + doInfer( + model, + new EmbeddingsInput(inputs.stream().map(ChunkInferenceInput::input).toList(), inputType), + taskSettings, + timeout, + inferListener + ); } - private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size()); - var results = new ArrayList(inputs.getInputs().size()); + var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.getInputs().size(); i++) { + for (int i = 0; i < inputs.size(); i++) { results.add( new ChunkedInferenceEmbedding( List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length()) + new ChunkedInference.TextOffset(0, inputs.get(i).input().length()) ) ) ) @@ -128,7 +137,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs()); + var inputsAsList = ChunkInferenceInput.inputs(inputs); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java index b7c679d3cda54..ebdb20d2bdcc5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); execute( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 7dfb0002bb062..e502d830c6ee4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -297,7 +298,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -306,13 +307,16 @@ protected void doChunkedInfer( IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; var batchedRequests = new EmbeddingRequestChunker<>( - input.getInputs(), + input, EMBEDDING_MAX_BATCH_SIZE, ibmWatsonxModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); + + var actionCreator = getActionCreator(getSender(), getServiceComponents()); + for (var request : batchedRequests) { - var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + var action = ibmWatsonxModel.accept(actionCreator, taskSettings); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index 083690a894c00..3e3918acb78dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -52,8 +52,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..8c5ac8f3b2a66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -265,7 +266,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -280,14 +281,14 @@ protected void doChunkedInfer( var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java index 391a549df924a..ea31435780b96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java @@ -61,7 +61,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..2a8d186d80e99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -146,7 +147,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -156,14 +157,14 @@ protected void doChunkedInfer( if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, MistralConstants.MAX_BATCH_SIZE, mistralEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..577c067ba730f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -325,7 +326,7 @@ public void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -340,14 +341,14 @@ protected void doChunkedInfer( var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, openAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..f05dd256cfd50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -285,7 +286,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -300,14 +301,14 @@ protected void doChunkedInfer( var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, getBatchSize(voyageaiModel), voyageaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java index 03753835177cb..5bf9bd66def2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getStringInputs(), + embeddingsInput.getInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 0bfa640b0cded..2682f74d0ebf9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -65,8 +64,6 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two")); - when(badInput.getInputs()).thenReturn(input); when(badInput.isSingleInput()).thenReturn(false); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java new file mode 100644 index 0000000000000..d6ba10b1932dc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.is; + +public class EmbeddingsInputTests extends ESTestCase { + public void testCallingGetInputs_invokesSupplier() { + AtomicBoolean invoked = new AtomicBoolean(); + final List list = List.of("input1", "input2"); + Supplier> supplier = () -> { + invoked.set(true); + return list; + }; + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + // Ensure we don't invoke the supplier until we call getInputs() + assertThat(invoked.get(), is(false)); + + assertThat(input.getInputs(), is(list)); + assertThat(invoked.get(), is(true)); + } + + public void testCallingGetInputsTwice_throws() { + Supplier> supplier = () -> List.of("input"); + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + input.getInputs(); + var exception = expectThrows(AssertionError.class, input::getInputs); + assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index d8c8095879b55..ee633f87cd3b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -143,7 +142,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), + new EmbeddingsInput(List.of("abc"), null), null, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java index e91b0b3451a77..9e7215ecc0a94 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -17,7 +17,7 @@ public class InferenceInputsTests extends ESTestCase { public void testCastToSucceeds() { - InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull(), false); + InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()); assertThat(inputs.castTo(EmbeddingsInput.class), Matchers.instanceOf(EmbeddingsInput.class)); var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); @@ -29,7 +29,7 @@ public void testCastToSucceeds() { } public void testCastToFails() { - InferenceInputs inputs = new EmbeddingsInput(List.of(), null, false); + InferenceInputs inputs = new EmbeddingsInput(List.of(), null); var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); assertThat( exception.getMessage(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index d40ee517a1c51..a232b7724ca98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -62,7 +61,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener @@ -82,7 +81,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -106,7 +105,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -135,7 +134,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -162,7 +161,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..659b935aff8d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -20,7 +21,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -131,7 +131,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java index 97bac8582c1fc..3ac706df819b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java @@ -122,7 +122,7 @@ public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatComplet PlainActionFuture listener = new PlainActionFuture<>(); assertThrows(IllegalArgumentException.class, () -> { action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, InputType.INGEST), + new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index 8b55d5b78f397..145a2e6078360 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -74,7 +74,7 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -112,7 +112,7 @@ public void testEmbeddingsRequestAction_Cohere() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); @@ -145,7 +145,7 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index 797d50878a0b7..aa19e9ae07d48 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -83,10 +82,10 @@ public void send( ActionListener listener ) { sendCounter++; - if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs())); - if (docsInput.getInputType() != null) { - inputTypes.add(docsInput.getInputType()); + if (inferenceInputs instanceof EmbeddingsInput embeddingsInput) { + inputs.add(embeddingsInput.getInputs()); + if (embeddingsInput.getInputType() != null) { + inputTypes.add(embeddingsInput.getInputType()); } } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) { inputs.add(chatCompletionInput.getInputs()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index 8a6a6a9c5eddf..4fbbfdd40ee30 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -112,7 +111,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws threadPool, new TimeValue(30, TimeUnit.SECONDS) ); - sender.send(requestManager, new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener); + sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index 9896286f503f3..29658db5a3d9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -115,7 +115,7 @@ public void testEmbeddingsRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index 5f3bbd5af0a16..5287907a2ce76 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -170,7 +170,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,7 +222,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [data]"; var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -296,7 +296,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -373,7 +373,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -433,11 +433,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of("super long input"), null, inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index 7d59ea225ab22..f9c85af00d4da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -117,11 +116,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,11 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -174,11 +165,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -199,11 +186,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -218,11 +201,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index 6438a328f9fcf..9dfd05f13cb6e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 05d69bae4903e..2d87dc3686040 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -124,7 +124,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); InputType inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -207,7 +207,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -262,7 +262,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -282,7 +282,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -302,7 +302,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -316,7 +316,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -330,7 +330,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java index 5e6ee1032ad5f..6dc33935ac617 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java @@ -77,7 +77,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { var listener = new PlainActionFuture(); var manager = CustomRequestManager.of(model, threadPool); - manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener); + manager.execute(new EmbeddingsInput(List.of("abc", "123"), null), mock(RequestSender.class), () -> false, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(30))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 106b55159b1dc..9c0ab64dbd996 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -82,10 +82,7 @@ public void testCreateRequest() throws IOException { ); var request = new CustomRequest( - EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, null), - model.getServiceSettings().getInputTypeTranslator() - ), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc", "123"), null), model.getServiceSettings().getInputTypeTranslator()), model ); var httpRequest = request.createHttpRequest(); @@ -146,7 +143,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx var request = new CustomRequest( EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST), + new EmbeddingsInput(List.of("abc", "123"), InputType.INGEST), model.getServiceSettings().getInputTypeTranslator() ), model @@ -207,7 +204,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws var request = new CustomRequest( EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH), + new EmbeddingsInput(List.of("abc", "123"), InputType.SEARCH), model.getServiceSettings().getInputTypeTranslator() ), model diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java index 5231802ef4b92..d6dc03c9daf75 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java @@ -21,7 +21,7 @@ public class EmbeddingParametersTests extends ESTestCase { public void testTaskTypeParameters_UsesDefaultValue() { var parameters = EmbeddingParameters.of( - new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new EmbeddingsInput(List.of("input"), InputType.INGEST), new InputTypeTranslator(Map.of(), "default") ); @@ -30,7 +30,7 @@ public void testTaskTypeParameters_UsesDefaultValue() { public void testTaskTypeParameters_UsesMappedValue() { var parameters = EmbeddingParameters.of( - new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new EmbeddingsInput(List.of("input"), InputType.INGEST), new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default") ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index 608fdb4d314c3..2eb61b50574ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -64,7 +64,7 @@ public void testFromTextEmbeddingResponse() throws IOException { new TextEmbeddingResponseParser("$.result.embeddings[*].embedding") ); var request = new CustomRequest( - EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), model ); InferenceServiceResults results = CustomResponseEntity.fromResponse( @@ -115,7 +115,7 @@ public void testFromSparseEmbeddingResponse() throws IOException { ) ); var request = new CustomRequest( - EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), model ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index c8701b47a20b5..ebcc9e6db20e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -102,7 +102,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -163,7 +163,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -291,7 +291,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world", "second text"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -347,7 +347,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH), + new EmbeddingsInput(List.of("search query"), InputType.SEARCH), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -402,7 +402,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -443,7 +443,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -492,7 +492,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -556,7 +556,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index 2668f7f8f7c27..f2077998d1797 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -187,7 +187,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -205,7 +205,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java index 5acd78930637b..8c6e5d31f59c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java @@ -75,7 +75,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -99,7 +99,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -117,7 +117,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java index 760858b5a1261..0e575ed045711 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java @@ -71,7 +71,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -91,7 +91,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -105,7 +105,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index b14cfcd14ec43..62d7099b327d9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -101,7 +101,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -172,7 +172,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -224,7 +224,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -286,7 +286,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -409,7 +409,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -471,7 +471,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java index cfa0f0bb2198b..f3bedf04e056f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java @@ -63,7 +63,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderThrows() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -87,7 +87,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -108,7 +108,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index 9376e4da76261..d9f3ed0c394db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -114,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -144,7 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -173,7 +173,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index d7c72cf98e267..102cdbec77d74 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -109,7 +109,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -222,7 +222,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -285,7 +285,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -625,7 +625,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -712,7 +712,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -784,7 +784,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index 17c08dee34e5c..4a1609c7a27df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -115,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -155,7 +154,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -179,7 +178,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -203,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -221,7 +220,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -239,7 +238,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index 7e89b5c3497c9..2f51ec22c791c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -111,11 +110,7 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index a4310412514cd..ff3f23b2d9027 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -121,11 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,11 +217,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -323,11 +314,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -396,7 +383,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -420,7 +407,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -438,7 +425,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -461,7 +448,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class );