diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2df2f1e62f89a..da071442d6c1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -37,16 +37,16 @@ * a single large input that has been chunked may spread over * multiple batches. * - * The final aspect it to gather the responses from the batch + * The final aspect is to gather the responses from the batch * processing and map the results back to the original element * in the input list. */ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) { public String chunkText() { - return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end()); + return input.substring(chunk.start(), chunk.end()); } } @@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); + String inputString = inputs.get(inputIndex).input(); + List chunks = chunker.chunk(inputString, chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); @@ -129,7 +130,7 @@ public EmbeddingRequestChunker( } else { resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end()); } - allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs)); + allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 55cdb7207e25d..f9fd3a2011ee0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,63 +8,47 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ChunkInferenceInput; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import java.util.stream.Collectors; public class EmbeddingsInput extends InferenceInputs { - - public static EmbeddingsInput of(InferenceInputs inferenceInputs) { - if (inferenceInputs instanceof EmbeddingsInput == false) { - throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class); - } - - return (EmbeddingsInput) inferenceInputs; - } - - private final Supplier> listSupplier; + private final Supplier> inputListSupplier; private final InputType inputType; + private final AtomicBoolean supplierInvoked = new AtomicBoolean(); - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { - super(false); - this.listSupplier = Objects.requireNonNull(inputSupplier); - this.inputType = inputType; + public EmbeddingsInput(List input, @Nullable InputType inputType) { + this(() -> input, inputType, false); } - public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { - this(input, chunkingSettings, inputType, false); + public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + this(() -> input, inputType, stream); } - public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) { - this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream); + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { + this(inputSupplier, inputType, false); } - public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(input, inputType, false); - } - - public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream) { super(stream); - Objects.requireNonNull(input); - this.listSupplier = () -> input; + this.inputListSupplier = Objects.requireNonNull(inputSupplier); this.inputType = inputType; } - public List getInputs() { - return this.listSupplier.get(); - } - - public static EmbeddingsInput fromStrings(List input, @Nullable InputType inputType) { - return new EmbeddingsInput(input, null, inputType); - } - - public List getStringInputs() { - return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList()); + /** + * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply + * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input + * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling + * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown. + * + * @return a list of String embedding inputs + */ + public List getInputs() { + assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice"; + return inputListSupplier.get(); } public InputType getInputType() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index c39387d647f77..4a485f87858aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 7ad799a613e4a..f0b25bd427b69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -105,7 +105,7 @@ private static InferenceInputs createInput( if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new EmbeddingsInput(input, null, inputType, stream); + yield new EmbeddingsInput(input, inputType, stream); } default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), @@ -144,7 +144,7 @@ public void chunkedInfer( } // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener); + doChunkedInfer(model, input, taskSettings, inputType, timeout, listener); } protected abstract void doInfer( @@ -168,7 +168,7 @@ protected abstract void doUnifiedCompletionInfer( protected abstract void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java index 438d31d8dd411..b677ec642075e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -29,7 +30,6 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; @@ -143,7 +143,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java index f11cd41b25aa0..8a77f65592226 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -71,8 +71,9 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 5383a4bfb2eec..f474850b9f190 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -336,7 +337,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -351,14 +352,14 @@ protected void doChunkedInfer( var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, alibabaCloudSearchModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java index acce3b9a1d6ea..20ff8ce58b550 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java @@ -71,8 +71,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java index 06910611e0a96..387d8b65f40d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java @@ -56,8 +56,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index c2b0ae8e69c37..11204018a5523 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -148,7 +149,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -159,14 +160,14 @@ protected void doChunkedInfer( var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, maxBatchSize, baseAmazonBedrockModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 8cf5446f8b6d5..224d62f83f28b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -28,7 +29,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -232,7 +232,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java index abb9b26a80b0c..afb268ab499a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java @@ -50,8 +50,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 718757d9e2697..7578aa702ad7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -135,7 +136,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -145,14 +146,14 @@ protected void doChunkedInfer( var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, baseAzureAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java index e98bf731210d7..db38b3fb0def3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java @@ -63,8 +63,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 3d9a3dd516a2d..077e5361dd46f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -282,7 +283,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -296,14 +297,14 @@ protected void doChunkedInfer( var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, azureOpenAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 4963c8646e5d6..2561f198075e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -285,7 +286,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -300,14 +301,14 @@ protected void doChunkedInfer( var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, cohereModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 777ddc348bda6..121f0e1e80a96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -81,8 +81,8 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map : overriddenModel.getTaskSettings().getInputType(); return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { - case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); - case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); }; }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index 85097b3bf12b0..952d654e1bd01 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -75,7 +75,7 @@ public void execute( requestParameters = CompletionParameters.of(chatInputs); } else if (inferenceInputs instanceof EmbeddingsInput) { requestParameters = EmbeddingParameters.of( - EmbeddingsInput.of(inferenceInputs), + inferenceInputs.castTo(EmbeddingsInput.class), model.getServiceSettings().getInputTypeTranslator() ); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 7cd069ac2e3e0..fd29b02012185 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -147,7 +148,7 @@ private static RequestParameters createParameters(CustomModel model) { case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input"))); case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input"))); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of( - new EmbeddingsInput(List.of("test input"), null, null), + new EmbeddingsInput(List.of("test input"), null), model.getServiceSettings().getInputTypeTranslator() ); default -> throw new IllegalStateException( @@ -291,7 +292,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -309,14 +310,14 @@ protected void doChunkedInfer( var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, customModel.getServiceSettings().getBatchSize(), customModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java index 71eb0ed8e098d..ef91045eb1dab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java @@ -28,7 +28,7 @@ public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeT private final InputTypeTranslator translator; private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) { - super(embeddingsInput.getStringInputs()); + super(embeddingsInput.getInputs()); this.inputType = embeddingsInput.getInputType(); this.translator = translator; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 8a77efbd604d2..54972e23a25d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -27,7 +28,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -116,7 +116,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index a0cb0f7ae1249..cc871da8eb860 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; @@ -370,7 +371,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -380,14 +381,14 @@ protected void doChunkedInfer( var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, denseTextEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } return; @@ -397,14 +398,14 @@ protected void doChunkedInfer( var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE, model.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } return; @@ -647,7 +648,7 @@ public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); + var inputsAsList = inputs.castTo(EmbeddingsInput.class).getInputs(); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index ea82eb228dbc8..f34d538c413dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -69,8 +69,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 8b987cd53bc81..749a3277929f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -97,7 +97,7 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m DENSE_TEXT_EMBEDDINGS_HANDLER, (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( model, - embeddingsInput.getStringInputs(), + embeddingsInput.getInputs(), traceContext, extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), embeddingsInput.getInputType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java index 13e54b9e3e17b..e65bfda857c1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java @@ -56,8 +56,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 4c8997f35555b..97bd2502d25b6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -350,7 +351,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -359,19 +360,13 @@ protected void doChunkedInfer( GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, googleAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer( - model, - EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), - taskSettings, - timeout, - request.listener() - ); + doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java index 2dc60ef114459..90d7a0b1b0a11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java @@ -64,8 +64,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 4e58e063eeebc..41678689e8b9d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -275,7 +276,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -285,14 +286,14 @@ protected void doChunkedInfer( var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, googleVertexAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index 7bb140e91ec5d..833f9bc6b347b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -62,8 +62,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); - var truncatedInput = truncate(docsInput, model.getTokenLimit()); + List inputs = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var truncatedInput = truncate(inputs, model.getTokenLimit()); var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index e0ad3f7460477..d0a98d8252923 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -14,6 +14,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -151,7 +152,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -166,14 +167,14 @@ protected void doChunkedInfer( var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, huggingFaceModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 5f9288bb99c24..775a4e90ae034 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -104,7 +104,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -115,22 +115,31 @@ protected void doChunkedInfer( ); // TODO chunking sparse embeddings not implemented - doInfer(model, inputs, taskSettings, timeout, inferListener); + doInfer( + model, + new EmbeddingsInput(inputs.stream().map(ChunkInferenceInput::input).toList(), inputType), + taskSettings, + timeout, + inferListener + ); } - private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size()); - var results = new ArrayList(inputs.getInputs().size()); + var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.getInputs().size(); i++) { + for (int i = 0; i < inputs.size(); i++) { results.add( new ChunkedInferenceEmbedding( List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length()) + new ChunkedInference.TextOffset(0, inputs.get(i).input().length()) ) ) ) @@ -138,7 +147,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs()); + var inputsAsList = ChunkInferenceInput.inputs(inputs); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java index c7a556321148e..92434d371c7e8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); execute( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 9617bff0d3f3d..8cdc8cd182425 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -351,7 +352,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -360,13 +361,16 @@ protected void doChunkedInfer( IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; var batchedRequests = new EmbeddingRequestChunker<>( - input.getInputs(), + input, EMBEDDING_MAX_BATCH_SIZE, ibmWatsonxModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); + + var actionCreator = getActionCreator(getSender(), getServiceComponents()); + for (var request : batchedRequests) { - var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + var action = ibmWatsonxModel.accept(actionCreator, taskSettings); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index 083690a894c00..3e3918acb78dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -52,8 +52,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index bed21c9ccb8bf..f6bd954617b76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -276,7 +277,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -291,14 +292,14 @@ protected void doChunkedInfer( var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java index 829dbe0a18955..a74f3202e5fb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -185,7 +186,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -200,14 +201,14 @@ protected void doChunkedInfer( var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, llamaModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = llamaModel.accept(actionCreator); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java index 52e284ba7ccca..f647338ba3110 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java @@ -73,7 +73,7 @@ public ExecutableAction create(LlamaEmbeddingsModel model) { EMBEDDINGS_HANDLER, embeddingsInput -> new LlamaEmbeddingsRequest( serviceComponents.truncator(), - truncate(embeddingsInput.getStringInputs(), model.getServiceSettings().maxInputTokens()), + truncate(embeddingsInput.getInputs(), model.getServiceSettings().maxInputTokens()), model ), EmbeddingsInput.class diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java index 391a549df924a..ea31435780b96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java @@ -61,7 +61,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 25adb439382c9..b114aa8081b9c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -150,7 +151,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -160,14 +161,14 @@ protected void doChunkedInfer( if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, MistralConstants.MAX_BATCH_SIZE, mistralEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index d2b7dcc527aaa..ae49f5dcef13b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -336,7 +337,7 @@ public void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -351,14 +352,14 @@ protected void doChunkedInfer( var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, openAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 75d568c6477fd..c69aeec203e4c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -307,7 +308,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -322,14 +323,14 @@ protected void doChunkedInfer( var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, getBatchSize(voyageaiModel), voyageaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java index 03753835177cb..5bf9bd66def2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getStringInputs(), + embeddingsInput.getInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 0bfa640b0cded..2682f74d0ebf9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -65,8 +64,6 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two")); - when(badInput.getInputs()).thenReturn(input); when(badInput.isSingleInput()).thenReturn(false); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java new file mode 100644 index 0000000000000..d6ba10b1932dc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.is; + +public class EmbeddingsInputTests extends ESTestCase { + public void testCallingGetInputs_invokesSupplier() { + AtomicBoolean invoked = new AtomicBoolean(); + final List list = List.of("input1", "input2"); + Supplier> supplier = () -> { + invoked.set(true); + return list; + }; + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + // Ensure we don't invoke the supplier until we call getInputs() + assertThat(invoked.get(), is(false)); + + assertThat(input.getInputs(), is(list)); + assertThat(invoked.get(), is(true)); + } + + public void testCallingGetInputsTwice_throws() { + Supplier> supplier = () -> List.of("input"); + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + input.getInputs(); + var exception = expectThrows(AssertionError.class, input::getInputs); + assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 9c8fd66912999..f56b54ecc916d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -143,7 +142,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), + new EmbeddingsInput(List.of("abc"), null), null, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java index e91b0b3451a77..9e7215ecc0a94 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -17,7 +17,7 @@ public class InferenceInputsTests extends ESTestCase { public void testCastToSucceeds() { - InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull(), false); + InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()); assertThat(inputs.castTo(EmbeddingsInput.class), Matchers.instanceOf(EmbeddingsInput.class)); var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); @@ -29,7 +29,7 @@ public void testCastToSucceeds() { } public void testCastToFails() { - InferenceInputs inputs = new EmbeddingsInput(List.of(), null, false); + InferenceInputs inputs = new EmbeddingsInput(List.of(), null); var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); assertThat( exception.getMessage(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index 4958dd8f90bc1..81557c0219a06 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -62,7 +61,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener @@ -82,7 +81,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -106,7 +105,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -135,7 +134,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -162,7 +161,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index ca107f0843b41..69e5228a927e7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -23,7 +24,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -215,7 +215,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java index 5c6be8b2cc86e..d054997be7d3b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.TaskType; @@ -75,7 +74,7 @@ public void testExecute_throwsElasticsearchStatusException_whenNumberOfInputsIsE } public void testExecute_throwsIllegalArgumentException_whenInputIsNotChatCompletion() { - var inputs = new EmbeddingsInput(List.of(new ChunkInferenceInput("input1")), InputType.SEARCH); + var inputs = new EmbeddingsInput(List.of("input1"), InputType.SEARCH); RequestSender mockSender = mock(RequestSender.class); PlainActionFuture listener = new PlainActionFuture<>(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java index 7cf7878e94580..b09fbf43a8ca4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java @@ -95,11 +95,7 @@ public void testExecute_withTextEmbeddingsAction_Success() { var action = createTextEmbeddingsAction(); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(values)))); @@ -110,11 +106,7 @@ public void testExecute_withTextEmbeddingsAction_ListenerThrowsElasticsearchExce var action = createTextEmbeddingsAction(); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is("error")); @@ -125,11 +117,7 @@ public void testExecute_withTextEmbeddingsAction_ListenerThrowsInternalServerErr var action = createTextEmbeddingsAction(); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search text embeddings request. Cause: error")); @@ -152,11 +140,7 @@ public void testExecute_withSparseEmbeddingsAction_Success() { var action = createSparseEmbeddingsAction(); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat( @@ -174,11 +158,7 @@ public void testExecute_withSparseEmbeddingsAction_ListenerThrowsElasticsearchEx var action = createSparseEmbeddingsAction(); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is("error")); @@ -189,11 +169,7 @@ public void testExecute_withSparseEmbeddingsAction_ListenerThrowsInternalServerE var action = createSparseEmbeddingsAction(); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search sparse embeddings request. Cause: error")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index ada5918a63a95..5dd42dc66485f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -74,7 +74,7 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -112,7 +112,7 @@ public void testEmbeddingsRequestAction_Cohere() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); @@ -145,7 +145,7 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index ac82367af6865..2e8d0bbd2ea14 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -78,10 +77,10 @@ public void send( ActionListener listener ) { sendCounter++; - if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs())); - if (docsInput.getInputType() != null) { - inputTypes.add(docsInput.getInputType()); + if (inferenceInputs instanceof EmbeddingsInput embeddingsInput) { + inputs.add(embeddingsInput.getInputs()); + if (embeddingsInput.getInputType() != null) { + inputTypes.add(embeddingsInput.getInputType()); } } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) { inputs.add(chatCompletionInput.getInputs()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index bbf2306a3e352..fed601805a748 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -112,7 +111,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws threadPool, new TimeValue(30, TimeUnit.SECONDS) ); - sender.send(requestManager, new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener); + sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index b7f6d97a27f96..daceedd2b8207 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -107,7 +107,7 @@ public void testEmbeddingsRequestAction() throws IOException { final var action = creator.create(model, Map.of()); final PlainActionFuture listener = new PlainActionFuture<>(); final var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); final var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index 7212490a568c1..c1d69580af8fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -170,7 +170,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,7 +222,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [data]"; var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -296,7 +296,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -373,7 +373,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -433,11 +433,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of("super long input"), null, inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index bb0bde3b29362..0c6b7f62a96b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -117,11 +116,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,11 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -174,11 +165,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -199,11 +186,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -218,11 +201,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index e6fd4f9e8caf0..8ab76bb728802 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 02c36a848b29d..ca068d3e1859d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -124,7 +124,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); InputType inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -207,7 +207,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -262,7 +262,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -282,7 +282,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -302,7 +302,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -316,7 +316,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -330,7 +330,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java index 145a305d5f5bc..6a45a412de18c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java @@ -77,7 +77,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { var listener = new PlainActionFuture(); var manager = CustomRequestManager.of(model, threadPool); - manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener); + manager.execute(new EmbeddingsInput(List.of("abc", "123"), null), mock(RequestSender.class), () -> false, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(30))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index aef1a6d3bbbc1..8b35979c3daf5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -77,10 +77,7 @@ public void testCreateRequest() throws IOException { ); var request = new CustomRequest( - EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, null), - model.getServiceSettings().getInputTypeTranslator() - ), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc", "123"), null), model.getServiceSettings().getInputTypeTranslator()), model ); var httpRequest = request.createHttpRequest(); @@ -141,7 +138,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx var request = new CustomRequest( EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST), + new EmbeddingsInput(List.of("abc", "123"), InputType.INGEST), model.getServiceSettings().getInputTypeTranslator() ), model @@ -197,7 +194,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws var request = new CustomRequest( EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH), + new EmbeddingsInput(List.of("abc", "123"), InputType.SEARCH), model.getServiceSettings().getInputTypeTranslator() ), model diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java index 5231802ef4b92..d6dc03c9daf75 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java @@ -21,7 +21,7 @@ public class EmbeddingParametersTests extends ESTestCase { public void testTaskTypeParameters_UsesDefaultValue() { var parameters = EmbeddingParameters.of( - new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new EmbeddingsInput(List.of("input"), InputType.INGEST), new InputTypeTranslator(Map.of(), "default") ); @@ -30,7 +30,7 @@ public void testTaskTypeParameters_UsesDefaultValue() { public void testTaskTypeParameters_UsesMappedValue() { var parameters = EmbeddingParameters.of( - new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new EmbeddingsInput(List.of("input"), InputType.INGEST), new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default") ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index 67bfe25fbdb6a..e53add6733aca 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -65,7 +65,7 @@ public void testFromTextEmbeddingResponse() throws IOException { new TextEmbeddingResponseParser("$.result.embeddings[*].embedding", CustomServiceEmbeddingType.FLOAT) ); var request = new CustomRequest( - EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), model ); InferenceServiceResults results = CustomResponseEntity.fromResponse( @@ -116,7 +116,7 @@ public void testFromSparseEmbeddingResponse() throws IOException { ) ); var request = new CustomRequest( - EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), model ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index 4ab11ffd6ee34..7ee6d817f899c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -102,7 +102,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -163,7 +163,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -291,7 +291,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world", "second text"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -347,7 +347,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH), + new EmbeddingsInput(List.of("search query"), InputType.SEARCH), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -402,7 +402,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -443,7 +443,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -492,7 +492,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -556,7 +556,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index e74aa310281a4..a36f06841ddcf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -187,7 +187,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -205,7 +205,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java index 8330fb58aa746..cf458fab6eb21 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java @@ -75,7 +75,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -99,7 +99,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -117,7 +117,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java index c4706cd545bc5..541d805943647 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java @@ -71,7 +71,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -91,7 +91,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -105,7 +105,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index ec433d59ae586..70632a439fdea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -101,7 +101,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -172,7 +172,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -224,7 +224,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -286,7 +286,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -409,7 +409,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -471,7 +471,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java index 7169466092721..a21d2d9f19e1d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java @@ -63,7 +63,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderThrows() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -87,7 +87,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -108,7 +108,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index f231da916a41d..c2ec051fd7a42 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -114,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -144,7 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -173,7 +173,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java index b59ee5105f318..709a4a0630ba0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -232,7 +232,7 @@ private PlainActionFuture createEmbeddingsFuture(Sender PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index 6b63212a308a1..df4e5cf4e3822 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -109,7 +109,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -222,7 +222,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -285,7 +285,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -625,7 +625,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -712,7 +712,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -784,7 +784,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index 3fd3ccdfbbfbb..6758f66f2917e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -115,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -155,7 +154,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -179,7 +178,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -203,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -221,7 +220,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -239,7 +238,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index fdc8ae5ffedee..e2802c3569d86 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -111,11 +110,7 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index c5aa5182373a8..ba97ba3b70d00 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -121,11 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,11 +217,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -323,11 +314,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -396,7 +383,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -420,7 +407,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -438,7 +425,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -461,7 +448,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class );