diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 2df2f1e62f89a..da071442d6c1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -37,16 +37,16 @@ * a single large input that has been chunked may spread over * multiple batches. * - * The final aspect it to gather the responses from the batch + * The final aspect is to gather the responses from the batch * processing and map the results back to the original element * in the input list. */ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, String input) { public String chunkText() { - return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end()); + return input.substring(chunk.start(), chunk.end()); } } @@ -60,7 +60,7 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); + String inputString = inputs.get(inputIndex).input(); + List chunks = chunker.chunk(inputString, chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); @@ -129,7 +130,7 @@ public EmbeddingRequestChunker( } else { resultOffsetEnds.getLast().set(targetChunkIndex, chunks.get(chunkIndex).end()); } - allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputs)); + allRequests.add(new Request(inputIndex, targetChunkIndex, chunks.get(chunkIndex), inputString)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 55cdb7207e25d..f9fd3a2011ee0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,63 +8,47 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ChunkInferenceInput; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; -import java.util.stream.Collectors; public class EmbeddingsInput extends InferenceInputs { - - public static EmbeddingsInput of(InferenceInputs inferenceInputs) { - if (inferenceInputs instanceof EmbeddingsInput == false) { - throw createUnsupportedTypeException(inferenceInputs, EmbeddingsInput.class); - } - - return (EmbeddingsInput) inferenceInputs; - } - - private final Supplier> listSupplier; + private final Supplier> inputListSupplier; private final InputType inputType; + private final AtomicBoolean supplierInvoked = new AtomicBoolean(); - public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { - super(false); - this.listSupplier = Objects.requireNonNull(inputSupplier); - this.inputType = inputType; + public EmbeddingsInput(List input, @Nullable InputType inputType) { + this(() -> input, inputType, false); } - public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) { - this(input, chunkingSettings, inputType, false); + public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + this(() -> input, inputType, stream); } - public EmbeddingsInput(List input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType, boolean stream) { - this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).toList(), inputType, stream); + public EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType) { + this(inputSupplier, inputType, false); } - public EmbeddingsInput(List input, @Nullable InputType inputType) { - this(input, inputType, false); - } - - public EmbeddingsInput(List input, @Nullable InputType inputType, boolean stream) { + private EmbeddingsInput(Supplier> inputSupplier, @Nullable InputType inputType, boolean stream) { super(stream); - Objects.requireNonNull(input); - this.listSupplier = () -> input; + this.inputListSupplier = Objects.requireNonNull(inputSupplier); this.inputType = inputType; } - public List getInputs() { - return this.listSupplier.get(); - } - - public static EmbeddingsInput fromStrings(List input, @Nullable InputType inputType) { - return new EmbeddingsInput(input, null, inputType); - } - - public List getStringInputs() { - return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList()); + /** + * Calling this method twice will result in the {@link #inputListSupplier} being invoked twice. In the case where the supplier simply + * returns the list passed into the constructor, this is not a problem, but in the case where a supplier that will chunk the input + * Strings when invoked is passed into the constructor, this will result in multiple copies of the input Strings being created. Calling + * this method twice in a non-production environment will cause an {@link AssertionError} to be thrown. + * + * @return a list of String embedding inputs + */ + public List getInputs() { + assert supplierInvoked.compareAndSet(false, true) : "EmbeddingsInput supplier invoked twice"; + return inputListSupplier.get(); } public InputType getInputType() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index c39387d647f77..4a485f87858aa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 657834e6831ff..0147c62823f0d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -101,7 +101,7 @@ private static InferenceInputs createInput( if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new EmbeddingsInput(input, null, inputType, stream); + yield new EmbeddingsInput(input, inputType, stream); } default -> throw new ElasticsearchStatusException( Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), @@ -140,7 +140,7 @@ public void chunkedInfer( } // a non-null query is not supported and is dropped by all providers - doChunkedInfer(model, new EmbeddingsInput(input, inputType), taskSettings, inputType, timeout, listener); + doChunkedInfer(model, input, taskSettings, inputType, timeout, listener); } protected abstract void doInfer( @@ -164,7 +164,7 @@ protected abstract void doUnifiedCompletionInfer( protected abstract void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java new file mode 100644 index 0000000000000..b677ec642075e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ai21/Ai21Service.java @@ -0,0 +1,308 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ai21; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.ai21.action.Ai21ActionCreator; +import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionModel; +import org.elasticsearch.xpack.inference.services.ai21.completion.Ai21ChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.ai21.request.Ai21ChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * Ai21Service is an implementation of the SenderService that handles inference tasks + * using AI21 models. It supports completion and chat completion tasks. + * The service uses Ai21ActionCreator to create actions for executing inference requests. + */ +public class Ai21Service extends SenderService { + public static final String NAME = "ai21"; + + private static final String SERVICE_NAME = "AI21"; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + private static final ResponseHandler CHAT_COMPLETION_HANDLER = new Ai21ChatCompletionResponseHandler( + "ai21 chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + private static final TransportVersion ML_INFERENCE_AI21_COMPLETION_ADDED = TransportVersion.fromName( + "ml_inference_ai21_completion_added" + ); + + public Ai21Service( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public Ai21Service(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new Ai21ActionCreator(getSender(), getServiceComponents()); + + if (Objects.requireNonNull(model) instanceof Ai21ChatCompletionModel ai21ChatCompletionModel) { + ai21ChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof Ai21ChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + Ai21ChatCompletionModel ai21ChatCompletionModel = (Ai21ChatCompletionModel) model; + var overriddenModel = Ai21ChatCompletionModel.of(ai21ChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new Ai21ChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = Ai21ActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + List inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + throw new UnsupportedOperationException("AI21 service does not support chunked inference"); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + Ai21Model model = createModel( + modelId, + taskType, + serviceSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Ai21Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public Ai21Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent(modelId, taskType, serviceSettingsMap, null, parsePersistedConfigErrorMsg(modelId, NAME)); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AI21_COMPLETION_ADDED; + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + } + + private static Ai21Model createModel( + String modelId, + TaskType taskType, + Map serviceSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case CHAT_COMPLETION, COMPLETION: + return new Ai21ChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + private Ai21Model createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + /** + * Configuration class for the AI21 inference service. + * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + ServiceFields.MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "Refer to the AI21 models documentation for the list of available inference models." + ) + .setLabel("Model") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java index f11cd41b25aa0..8a77f65592226 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -71,8 +71,9 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 7897317736c72..ff6482a795f77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -321,7 +322,7 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -336,14 +337,14 @@ protected void doChunkedInfer( var actionCreator = new AlibabaCloudSearchActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, alibabaCloudSearchModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java index acce3b9a1d6ea..20ff8ce58b550 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchSparseRequestManager.java @@ -71,8 +71,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java index 06910611e0a96..387d8b65f40d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockEmbeddingsRequestManager.java @@ -56,8 +56,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 591607953ea1a..04122127b72d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -136,7 +137,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -147,14 +148,14 @@ protected void doChunkedInfer( var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, maxBatchSize, baseAmazonBedrockModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 791518ccc9168..8be92ce68abb7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -26,7 +27,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -222,7 +222,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java index abb9b26a80b0c..afb268ab499a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioEmbeddingsRequestManager.java @@ -50,8 +50,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 04883f23b947f..34ebac37a2a55 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -123,7 +124,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -133,14 +134,14 @@ protected void doChunkedInfer( var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, baseAzureAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java index e98bf731210d7..db38b3fb0def3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiEmbeddingsRequestManager.java @@ -63,8 +63,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index e9ff97c1ba725..66fb753d76b7d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -272,7 +273,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -286,14 +287,14 @@ protected void doChunkedInfer( var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, azureOpenAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index c2f1221763165..6c87ee5b86d68 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -274,7 +275,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -289,14 +290,14 @@ protected void doChunkedInfer( var actionCreator = new CohereActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, cohereModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 777ddc348bda6..121f0e1e80a96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -81,8 +81,8 @@ public ExecutableAction create(CohereEmbeddingsModel model, Map : overriddenModel.getTaskSettings().getInputType(); return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { - case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); - case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getInputs(), requestInputType, overriddenModel); }; }; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java index 6e2400c33d6a2..1c08fdb9f4285 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java @@ -75,7 +75,7 @@ public void execute( requestParameters = CompletionParameters.of(chatInputs); } else if (inferenceInputs instanceof EmbeddingsInput) { requestParameters = EmbeddingParameters.of( - EmbeddingsInput.of(inferenceInputs), + inferenceInputs.castTo(EmbeddingsInput.class), model.getServiceSettings().getInputTypeTranslator() ); } else { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 4a4166cf65ed3..16acfaa1af430 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -136,7 +137,7 @@ private static RequestParameters createParameters(CustomModel model) { case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input"))); case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input"))); case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of( - new EmbeddingsInput(List.of("test input"), null, null), + new EmbeddingsInput(List.of("test input"), null), model.getServiceSettings().getInputTypeTranslator() ); default -> throw new IllegalStateException( @@ -280,7 +281,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -298,14 +299,14 @@ protected void doChunkedInfer( var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, customModel.getServiceSettings().getBatchSize(), customModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java index 71eb0ed8e098d..ef91045eb1dab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java @@ -28,7 +28,7 @@ public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeT private final InputTypeTranslator translator; private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) { - super(embeddingsInput.getStringInputs()); + super(embeddingsInput.getInputs()); this.inputType = embeddingsInput.getInputType(); this.translator = translator; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 56719199e094f..22638ed7463a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -25,7 +26,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -106,7 +106,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 4f8d0d01861cd..4c14b507ecb2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySecretSettings; @@ -351,7 +352,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -361,14 +362,14 @@ protected void doChunkedInfer( var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE, denseTextEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = denseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } return; @@ -378,14 +379,14 @@ protected void doChunkedInfer( var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE, model.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } return; @@ -618,7 +619,7 @@ public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); + var inputsAsList = inputs.castTo(EmbeddingsInput.class).getInputs(); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index ea82eb228dbc8..f34d538c413dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -69,8 +69,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java index 8b987cd53bc81..749a3277929f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java @@ -97,7 +97,7 @@ public ExecutableAction create(ElasticInferenceServiceDenseTextEmbeddingsModel m DENSE_TEXT_EMBEDDINGS_HANDLER, (embeddingsInput) -> new ElasticInferenceServiceDenseTextEmbeddingsRequest( model, - embeddingsInput.getStringInputs(), + embeddingsInput.getInputs(), traceContext, extractRequestMetadataFromThreadContext(threadPool.getThreadContext()), embeddingsInput.getInputType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java index 13e54b9e3e17b..e65bfda857c1b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioEmbeddingsRequestManager.java @@ -56,8 +56,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 9841ea64370c3..8665e744b76e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -340,7 +341,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -349,19 +350,13 @@ protected void doChunkedInfer( GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, googleAiStudioModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer( - model, - EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), - taskSettings, - timeout, - request.listener() - ); + doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java index 2dc60ef114459..90d7a0b1b0a11 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiEmbeddingsRequestManager.java @@ -64,8 +64,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3b59e999125e5..6a725a0395f7b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -264,7 +265,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -274,14 +275,14 @@ protected void doChunkedInfer( var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, googleVertexAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index 7bb140e91ec5d..833f9bc6b347b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -62,8 +62,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); - var truncatedInput = truncate(docsInput, model.getTokenLimit()); + List inputs = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var truncatedInput = truncate(inputs, model.getTokenLimit()); var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index d10fb77290c6b..4c110505706c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -140,7 +141,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -155,14 +156,14 @@ protected void doChunkedInfer( var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, huggingFaceModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index e61995aac91f3..934ace507219f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -94,7 +94,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -105,22 +105,31 @@ protected void doChunkedInfer( ); // TODO chunking sparse embeddings not implemented - doInfer(model, inputs, taskSettings, timeout, inferListener); + doInfer( + model, + new EmbeddingsInput(inputs.stream().map(ChunkInferenceInput::input).toList(), inputType), + taskSettings, + timeout, + inferListener + ); } - private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs), textEmbeddingResults.embeddings().size()); - var results = new ArrayList(inputs.getInputs().size()); + var results = new ArrayList(inputs.size()); - for (int i = 0; i < inputs.getInputs().size(); i++) { + for (int i = 0; i < inputs.size(); i++) { results.add( new ChunkedInferenceEmbedding( List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length()) + new ChunkedInference.TextOffset(0, inputs.get(i).input().length()) ) ) ) @@ -128,7 +137,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs()); + var inputsAsList = ChunkInferenceInput.inputs(inputs); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java index b7c679d3cda54..ebdb20d2bdcc5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); execute( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 7dfb0002bb062..e502d830c6ee4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -297,7 +298,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -306,13 +307,16 @@ protected void doChunkedInfer( IbmWatsonxModel ibmWatsonxModel = (IbmWatsonxModel) model; var batchedRequests = new EmbeddingRequestChunker<>( - input.getInputs(), + input, EMBEDDING_MAX_BATCH_SIZE, ibmWatsonxModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); + + var actionCreator = getActionCreator(getSender(), getServiceComponents()); + for (var request : batchedRequests) { - var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + var action = ibmWatsonxModel.accept(actionCreator, taskSettings); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java index 083690a894c00..3e3918acb78dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIEmbeddingsRequestManager.java @@ -52,8 +52,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); + EmbeddingsInput input = inferenceInputs.castTo(EmbeddingsInput.class); + List docsInput = input.getInputs(); InputType inputType = input.getInputType(); JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index c2e88cb6cdc7c..8c5ac8f3b2a66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -265,7 +266,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -280,14 +281,14 @@ protected void doChunkedInfer( var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, jinaaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java new file mode 100644 index 0000000000000..a74f3202e5fb4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/LlamaService.java @@ -0,0 +1,418 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.llama.action.LlamaActionCreator; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.inference.TaskType.CHAT_COMPLETION; +import static org.elasticsearch.inference.TaskType.COMPLETION; +import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +/** + * LlamaService is an inference service for Llama models, supporting text embedding and chat completion tasks. + * It extends SenderService to handle HTTP requests and responses for Llama models. + */ +public class LlamaService extends SenderService { + public static final String NAME = "llama"; + private static final String SERVICE_NAME = "Llama"; + private static final TransportVersion ML_INFERENCE_LLAMA_ADDED = TransportVersion.fromName("ml_inference_llama_added"); + /** + * The optimal batch size depends on the hardware the model is deployed on. + * For Llama use a conservatively small max batch size as it is + * unknown how the model is deployed + */ + static final int EMBEDDING_MAX_BATCH_SIZE = 20; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new LlamaChatCompletionResponseHandler( + "llama chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + /** + * Constructor for creating a LlamaService with specified HTTP request sender factory and service components. + * + * @param factory the factory to create HTTP request senders + * @param serviceComponents the components required for the inference service + * @param context the context for the inference service factory + */ + public LlamaService( + HttpRequestSender.Factory factory, + ServiceComponents serviceComponents, + InferenceServiceExtension.InferenceServiceFactoryContext context + ) { + this(factory, serviceComponents, context.clusterService()); + } + + public LlamaService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) { + super(factory, serviceComponents, clusterService); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + if (model instanceof LlamaModel llamaModel) { + llamaModel.accept(actionCreator).execute(inputs, timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void validateInputType(InputType inputType, Model model, ValidationException validationException) { + ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException); + } + + /** + * Creates a LlamaModel based on the provided parameters. + * + * @param inferenceId the unique identifier for the inference entity + * @param taskType the type of task this model is designed for + * @param serviceSettings the settings for the inference service + * @param chunkingSettings the settings for chunking, if applicable + * @param secretSettings the secret settings for the model, such as API keys or tokens + * @param failureMessage the message to use in case of failure + * @param context the context for parsing configuration settings + * @return a new instance of LlamaModel based on the provided parameters + */ + protected LlamaModel createModel( + String inferenceId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING: + return new LlamaEmbeddingsModel(inferenceId, taskType, NAME, serviceSettings, chunkingSettings, secretSettings, context); + case CHAT_COMPLETION, COMPLETION: + return new LlamaChatCompletionModel(inferenceId, taskType, NAME, serviceSettings, secretSettings, context); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof LlamaEmbeddingsModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + var updatedServiceSettings = new LlamaEmbeddingsServiceSettings( + serviceSettings.modelId(), + serviceSettings.uri(), + embeddingSize, + similarityToUse, + serviceSettings.maxInputTokens(), + serviceSettings.rateLimitSettings() + ); + + return new LlamaEmbeddingsModel(embeddingsModel, updatedServiceSettings); + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } + } + + @Override + protected void doChunkedInfer( + Model model, + List inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + if (model instanceof LlamaEmbeddingsModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaModel = (LlamaEmbeddingsModel) model; + var actionCreator = new LlamaActionCreator(getSender(), getServiceComponents()); + + List batchedRequests = new EmbeddingRequestChunker<>( + inputs, + EMBEDDING_MAX_BATCH_SIZE, + llamaModel.getConfigurations().getChunkingSettings() + ).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = llamaModel.accept(actionCreator); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + } + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof LlamaChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + var llamaChatCompletionModel = (LlamaChatCompletionModel) model; + var overriddenModel = LlamaChatCompletionModel.of(llamaChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new LlamaChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = LlamaActionCreator.buildErrorMessage(CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(COMPLETION, CHAT_COMPLETION); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap( + removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS) + ); + } + + LlamaModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + private LlamaModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + chunkingSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + ChunkingSettings chunkingSettings = null; + if (TaskType.TEXT_EMBEDDING.equals(taskType)) { + chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); + } + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + chunkingSettings, + null, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_LLAMA_ADDED; + } + + /** + * Configuration class for the Llama inference service. + * It provides the settings and configurations required for the service. + */ + public static class Configuration { + public static InferenceServiceConfiguration get() { + return CONFIGURATION.getOrCompute(); + } + + private Configuration() {} + + private static final LazyInitializable CONFIGURATION = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription( + "Refer to the Llama models documentation for the list of available models." + ) + .setLabel("Model") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java new file mode 100644 index 0000000000000..f647338ba3110 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreator.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModel; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.llama.request.completion.LlamaChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.llama.request.embeddings.LlamaEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; + +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +/** + * Creates actions for Llama inference requests, handling both embeddings and completions. + * This class implements the {@link LlamaActionVisitor} interface to provide specific action creation methods. + */ +public class LlamaActionCreator implements LlamaActionVisitor { + + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Llama %s request from inference entity id [%s]"; + private static final String COMPLETION_ERROR_PREFIX = "Llama completions"; + private static final String USER_ROLE = "user"; + + private static final ResponseHandler EMBEDDINGS_HANDLER = new LlamaEmbeddingsResponseHandler( + "llama text embedding", + HuggingFaceEmbeddingsResponseEntity::fromResponse + ); + private static final ResponseHandler COMPLETION_HANDLER = new LlamaCompletionResponseHandler( + "llama completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); + + private final Sender sender; + private final ServiceComponents serviceComponents; + + /** + * Constructs a new LlamaActionCreator with the specified sender and service components. + * + * @param sender the sender to use for executing actions + * @param serviceComponents the service components providing necessary services + */ + public LlamaActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(LlamaEmbeddingsModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + embeddingsInput -> new LlamaEmbeddingsRequest( + serviceComponents.truncator(), + truncate(embeddingsInput.getInputs(), model.getServiceSettings().maxInputTokens()), + model + ), + EmbeddingsInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); + return new SenderExecutableAction(sender, manager, errorMessage); + } + + @Override + public ExecutableAction create(LlamaChatCompletionModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new LlamaChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + /** + * Builds an error message for failed requests. + * + * @param requestType the type of request that failed + * @param inferenceId the inference entity ID associated with the request + * @return a formatted error message + */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, requestType.toString(), inferenceId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java index 391a549df924a..ea31435780b96 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java @@ -61,7 +61,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + List docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index b11feb117d761..2a8d186d80e99 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -146,7 +147,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -156,14 +157,14 @@ protected void doChunkedInfer( if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, MistralConstants.MAX_BATCH_SIZE, mistralEmbeddingsModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index edff1dfc08cba..577c067ba730f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -325,7 +326,7 @@ public void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -340,14 +341,14 @@ protected void doChunkedInfer( var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, EMBEDDING_MAX_BATCH_SIZE, openAiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 0ffec057dc2b4..f05dd256cfd50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -285,7 +286,7 @@ protected void validateInputType(InputType inputType, Model model, ValidationExc @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, @@ -300,14 +301,14 @@ protected void doChunkedInfer( var actionCreator = new VoyageAIActionCreator(getSender(), getServiceComponents()); List batchedRequests = new EmbeddingRequestChunker<>( - inputs.getInputs(), + inputs, getBatchSize(voyageaiModel), voyageaiModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); + action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java index 03753835177cb..5bf9bd66def2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getStringInputs(), + embeddingsInput.getInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 0bfa640b0cded..2682f74d0ebf9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -65,8 +64,6 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two")); - when(badInput.getInputs()).thenReturn(input); when(badInput.isSingleInput()).thenReturn(false); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java new file mode 100644 index 0000000000000..d6ba10b1932dc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInputTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.is; + +public class EmbeddingsInputTests extends ESTestCase { + public void testCallingGetInputs_invokesSupplier() { + AtomicBoolean invoked = new AtomicBoolean(); + final List list = List.of("input1", "input2"); + Supplier> supplier = () -> { + invoked.set(true); + return list; + }; + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + // Ensure we don't invoke the supplier until we call getInputs() + assertThat(invoked.get(), is(false)); + + assertThat(input.getInputs(), is(list)); + assertThat(invoked.get(), is(true)); + } + + public void testCallingGetInputsTwice_throws() { + Supplier> supplier = () -> List.of("input"); + EmbeddingsInput input = new EmbeddingsInput(supplier, null); + input.getInputs(); + var exception = expectThrows(AssertionError.class, input::getInputs); + assertThat(exception.getMessage(), is("EmbeddingsInput supplier invoked twice")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index d8c8095879b55..ee633f87cd3b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -143,7 +142,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), + new EmbeddingsInput(List.of("abc"), null), null, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java index e91b0b3451a77..9e7215ecc0a94 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -17,7 +17,7 @@ public class InferenceInputsTests extends ESTestCase { public void testCastToSucceeds() { - InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull(), false); + InferenceInputs inputs = new EmbeddingsInput(List.of(), InputTypeTests.randomWithNull()); assertThat(inputs.castTo(EmbeddingsInput.class), Matchers.instanceOf(EmbeddingsInput.class)); var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); @@ -29,7 +29,7 @@ public void testCastToSucceeds() { } public void testCastToFails() { - InferenceInputs inputs = new EmbeddingsInput(List.of(), null, false); + InferenceInputs inputs = new EmbeddingsInput(List.of(), null); var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); assertThat( exception.getMessage(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index d40ee517a1c51..a232b7724ca98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -62,7 +61,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener @@ -82,7 +81,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -106,7 +105,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -135,7 +134,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -162,7 +161,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 5d7a6a149f941..659b935aff8d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; @@ -20,7 +21,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -131,7 +131,7 @@ protected void doUnifiedCompletionInfer( @Override protected void doChunkedInfer( Model model, - EmbeddingsInput inputs, + List inputs, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java new file mode 100644 index 0000000000000..d054997be7d3b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchCompletionRequestManagerTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.completion.AlibabaCloudSearchCompletionRequest; +import org.mockito.ArgumentCaptor; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; +import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AlibabaCloudSearchCompletionRequestManagerTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + + public void testExecute_executesRequest() { + var inputs = new ChatCompletionInput(List.of("input1", "input2", "input3")); + RequestSender mockSender = mock(RequestSender.class); + PlainActionFuture listener = new PlainActionFuture<>(); + + ExecutorService mockExecutorService = mock(ExecutorService.class); + var requestManager = createRequestManagerWithMockExecutor(mockExecutorService); + requestManager.execute(inputs, mockSender, () -> false, listener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ExecutableInferenceRequest.class); + verify(mockExecutorService).execute(captor.capture()); + + ExecutableInferenceRequest executableRequest = captor.getValue(); + assertThat(executableRequest.request(), is(instanceOf(AlibabaCloudSearchCompletionRequest.class))); + assertThat(executableRequest.responseHandler().getRequestType(), is("alibaba cloud search completion")); + } + + public void testExecute_throwsElasticsearchStatusException_whenNumberOfInputsIsEven() { + var inputs = new ChatCompletionInput(List.of("input1", "input2")); + RequestSender mockSender = mock(RequestSender.class); + PlainActionFuture listener = new PlainActionFuture<>(); + + var requestManager = createRequestManagerWithMockExecutor(mock(ExecutorService.class)); + requestManager.execute(inputs, mockSender, () -> false, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.status(), is(BAD_REQUEST)); + assertThat(thrownException.getMessage(), containsString("Alibaba Completion's inputs must be an odd number")); + } + + public void testExecute_throwsIllegalArgumentException_whenInputIsNotChatCompletion() { + var inputs = new EmbeddingsInput(List.of("input1"), InputType.SEARCH); + RequestSender mockSender = mock(RequestSender.class); + PlainActionFuture listener = new PlainActionFuture<>(); + + var requestManager = createRequestManagerWithMockExecutor(mock(ExecutorService.class)); + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> requestManager.execute(inputs, mockSender, () -> false, listener) + ); + + assertThat(thrownException.getMessage(), containsString("Unable to convert inference inputs type")); + } + + private AlibabaCloudSearchCompletionRequestManager createRequestManagerWithMockExecutor(ExecutorService mockExecutorService) { + ThreadPool mockThreadPool = mock(ThreadPool.class); + when(mockThreadPool.executor(anyString())).thenReturn(mockExecutorService); + + var model = AlibabaCloudSearchCompletionModelTests.createModel( + "completion_test", + TaskType.COMPLETION, + AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"), + AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null), + getSecretSettingsMap("secret") + ); + var account = new AlibabaCloudSearchAccount(model.getSecretSettings().apiKey()); + return AlibabaCloudSearchCompletionRequestManager.of(account, model, mockThreadPool); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java new file mode 100644 index 0000000000000..b09fbf43a8ca4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchActionCreatorTests.java @@ -0,0 +1,337 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.alibabacloudsearch.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.WeightedToken; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.request.AlibabaCloudSearchUtils; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.rerank.AlibabaCloudSearchRerankModel; +import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.core.inference.results.RankedDocsResultsTests.buildExpectationRerank; +import static org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests.buildExpectationSparseEmbeddings; +import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class AlibabaCloudSearchActionCreatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private Sender sender; + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws IOException { + sender = mock(Sender.class); + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_withTextEmbeddingsAction_Success() { + float[] values = { 0.1111111f, 0.2222222f, 0.3333333f }; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new TextEmbeddingFloatResults(List.of(new TextEmbeddingFloatResults.Embedding(values)))); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + var action = createTextEmbeddingsAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(values)))); + } + + public void testExecute_withTextEmbeddingsAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() { + doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any()); + var action = createTextEmbeddingsAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("error")); + } + + public void testExecute_withTextEmbeddingsAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() { + doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any()); + var action = createTextEmbeddingsAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search text embeddings request. Cause: error")); + } + + public void testExecute_withSparseEmbeddingsAction_Success() { + String token = "token"; + float weight = 0.1111111f; + boolean isTruncated = false; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse( + new SparseEmbeddingResults( + List.of(new SparseEmbeddingResults.Embedding(List.of(new WeightedToken(token, weight)), isTruncated)) + ) + ); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + var action = createSparseEmbeddingsAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationSparseEmbeddings( + List.of(new SparseEmbeddingResultsTests.EmbeddingExpectation(Map.of(token, weight), isTruncated)) + ) + ) + ); + } + + public void testExecute_withSparseEmbeddingsAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() { + doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any()); + var action = createSparseEmbeddingsAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("error")); + } + + public void testExecute_withSparseEmbeddingsAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() { + doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any()); + var action = createSparseEmbeddingsAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search sparse embeddings request. Cause: error")); + } + + public void testExecute_withRerankAction_Success() { + int index = 0; + float relevanceScore = 0.1111111f; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new RankedDocsResults(List.of(new RankedDocsResults.RankedDoc(index, relevanceScore, null)))); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + var action = createRerankAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs("query", List.of(randomAlphaOfLength(10))), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result.asMap(), + is( + buildExpectationRerank( + List.of(new RankedDocsResultsTests.RerankExpectation(Map.of("index", index, "relevance_score", relevanceScore))) + ) + ) + ); + } + + public void testExecute_withRerankAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() { + doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any()); + var action = createRerankAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs("query", List.of(randomAlphaOfLength(10))), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("error")); + } + + public void testExecute_withRerankAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() { + doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any()); + var action = createRerankAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new QueryAndDocsInputs("query", List.of(randomAlphaOfLength(10))), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search rerank request. Cause: error")); + } + + public void testExecute_withCompletionAction_Success() { + var resultString = randomAlphaOfLength(100); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString)))); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + var action = createCompletionAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString)))); + } + + public void testExecute_withCompletionAction_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() { + doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any()); + var action = createCompletionAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("error")); + } + + public void testExecute_withCompletionAction_ListenerThrowsInternalServerError_WhenSenderThrowsException() { + doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any()); + var action = createCompletionAction(); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(thrownException.getMessage(), is("Failed to send AlibabaCloud Search completion request. Cause: error")); + } + + private ExecutableAction createTextEmbeddingsAction() { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + AlibabaCloudSearchEmbeddingsModel embeddingsModel = new AlibabaCloudSearchEmbeddingsModel( + "text_embedding_test", + TaskType.TEXT_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("text_embedding_test", "host", "default"), + null, + null, + DefaultSecretSettingsTests.getSecretSettingsMap("secret"), + null + ); + var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents); + return actionCreator.create(embeddingsModel, Map.of()); + } + + private ExecutableAction createSparseEmbeddingsAction() { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + AlibabaCloudSearchSparseModel sparseModel = new AlibabaCloudSearchSparseModel( + "sparse_embedding_test", + TaskType.SPARSE_EMBEDDING, + AlibabaCloudSearchUtils.SERVICE_NAME, + AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("sparse_embedding_test", "host", "default"), + null, + null, + DefaultSecretSettingsTests.getSecretSettingsMap("secret"), + null + ); + var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents); + return actionCreator.create(sparseModel, Map.of()); + } + + private ExecutableAction createRerankAction() { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + AlibabaCloudSearchRerankModel rerankModel = new AlibabaCloudSearchRerankModel( + "rerank_test", + TaskType.RERANK, + AlibabaCloudSearchUtils.SERVICE_NAME, + AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("rerank_test", "host", "default"), + null, + DefaultSecretSettingsTests.getSecretSettingsMap("secret"), + null + ); + var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents); + return actionCreator.create(rerankModel, Map.of()); + } + + private ExecutableAction createCompletionAction() { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + AlibabaCloudSearchCompletionModel completionModel = new AlibabaCloudSearchCompletionModel( + "completion_test", + TaskType.COMPLETION, + AlibabaCloudSearchUtils.SERVICE_NAME, + AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"), + null, + DefaultSecretSettingsTests.getSecretSettingsMap("secret"), + null + ); + var actionCreator = new AlibabaCloudSearchActionCreator(sender, serviceComponents); + return actionCreator.create(completionModel, Map.of()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index 8b55d5b78f397..145a2e6078360 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -74,7 +74,7 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -112,7 +112,7 @@ public void testEmbeddingsRequestAction_Cohere() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); @@ -145,7 +145,7 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index 797d50878a0b7..aa19e9ae07d48 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -83,10 +82,10 @@ public void send( ActionListener listener ) { sendCounter++; - if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs())); - if (docsInput.getInputType() != null) { - inputTypes.add(docsInput.getInputType()); + if (inferenceInputs instanceof EmbeddingsInput embeddingsInput) { + inputs.add(embeddingsInput.getInputs()); + if (embeddingsInput.getInputType() != null) { + inputTypes.add(embeddingsInput.getInputType()); } } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) { inputs.add(chatCompletionInput.getInputs()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index 8a6a6a9c5eddf..4fbbfdd40ee30 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -112,7 +111,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws threadPool, new TimeValue(30, TimeUnit.SECONDS) ); - sender.send(requestManager, new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener); + sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index 9896286f503f3..afa56be0e58b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -111,11 +111,11 @@ public void testEmbeddingsRequestAction() throws IOException { ); model.setURI(getUrl(webServer)); - var creator = new AzureAiStudioActionCreator(sender, serviceComponents); - var action = creator.create(model, Map.of()); - PlainActionFuture listener = new PlainActionFuture<>(); - var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + final var creator = new AzureAiStudioActionCreator(sender, serviceComponents); + final var action = creator.create(model, Map.of()); + final PlainActionFuture listener = new PlainActionFuture<>(); + final var inputType = InputTypeTests.randomSearchAndIngestWithNull(); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index 5f3bbd5af0a16..5287907a2ce76 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -170,7 +170,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,7 +222,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [data]"; var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -296,7 +296,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -373,7 +373,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -433,11 +433,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute( - new EmbeddingsInput(List.of("super long input"), null, inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index 7d59ea225ab22..f9c85af00d4da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -117,11 +116,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,11 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -174,11 +165,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -199,11 +186,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -218,11 +201,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index 6438a328f9fcf..9dfd05f13cb6e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 05d69bae4903e..2d87dc3686040 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -124,7 +124,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); InputType inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -207,7 +207,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -262,7 +262,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -282,7 +282,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -302,7 +302,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -316,7 +316,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -330,7 +330,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java index 5e6ee1032ad5f..6dc33935ac617 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java @@ -77,7 +77,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() { var listener = new PlainActionFuture(); var manager = CustomRequestManager.of(model, threadPool); - manager.execute(new EmbeddingsInput(List.of("abc", "123"), null, null), mock(RequestSender.class), () -> false, listener); + manager.execute(new EmbeddingsInput(List.of("abc", "123"), null), mock(RequestSender.class), () -> false, listener); var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(30))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index 106b55159b1dc..9c0ab64dbd996 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -82,10 +82,7 @@ public void testCreateRequest() throws IOException { ); var request = new CustomRequest( - EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, null), - model.getServiceSettings().getInputTypeTranslator() - ), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc", "123"), null), model.getServiceSettings().getInputTypeTranslator()), model ); var httpRequest = request.createHttpRequest(); @@ -146,7 +143,7 @@ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOEx var request = new CustomRequest( EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST), + new EmbeddingsInput(List.of("abc", "123"), InputType.INGEST), model.getServiceSettings().getInputTypeTranslator() ), model @@ -207,7 +204,7 @@ public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws var request = new CustomRequest( EmbeddingParameters.of( - new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH), + new EmbeddingsInput(List.of("abc", "123"), InputType.SEARCH), model.getServiceSettings().getInputTypeTranslator() ), model diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java index 5231802ef4b92..d6dc03c9daf75 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java @@ -21,7 +21,7 @@ public class EmbeddingParametersTests extends ESTestCase { public void testTaskTypeParameters_UsesDefaultValue() { var parameters = EmbeddingParameters.of( - new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new EmbeddingsInput(List.of("input"), InputType.INGEST), new InputTypeTranslator(Map.of(), "default") ); @@ -30,7 +30,7 @@ public void testTaskTypeParameters_UsesDefaultValue() { public void testTaskTypeParameters_UsesMappedValue() { var parameters = EmbeddingParameters.of( - new EmbeddingsInput(List.of("input"), null, InputType.INGEST), + new EmbeddingsInput(List.of("input"), InputType.INGEST), new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default") ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java index 608fdb4d314c3..2eb61b50574ee 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java @@ -64,7 +64,7 @@ public void testFromTextEmbeddingResponse() throws IOException { new TextEmbeddingResponseParser("$.result.embeddings[*].embedding") ); var request = new CustomRequest( - EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), model ); InferenceServiceResults results = CustomResponseEntity.fromResponse( @@ -115,7 +115,7 @@ public void testFromSparseEmbeddingResponse() throws IOException { ) ); var request = new CustomRequest( - EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()), + EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null), model.getServiceSettings().getInputTypeTranslator()), model ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index c8701b47a20b5..ebcc9e6db20e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -102,7 +102,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -163,7 +163,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -291,7 +291,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction() PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world", "second text"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world", "second text"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -347,7 +347,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_W PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("search query"), null, InputType.SEARCH), + new EmbeddingsInput(List.of("search query"), InputType.SEARCH), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -402,7 +402,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForDenseTextEmbeddingsAction PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -443,7 +443,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForDenseTextEmbeddingsAction_E var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of(), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -492,7 +492,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -556,7 +556,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index 2668f7f8f7c27..f2077998d1797 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -187,7 +187,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -205,7 +205,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java index 5acd78930637b..8c6e5d31f59c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java @@ -75,7 +75,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -99,7 +99,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -117,7 +117,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java index 760858b5a1261..0e575ed045711 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java @@ -71,7 +71,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -91,7 +91,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -105,7 +105,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index b14cfcd14ec43..62d7099b327d9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -101,7 +101,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -172,7 +172,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -224,7 +224,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -286,7 +286,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -409,7 +409,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -471,7 +471,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java index cfa0f0bb2198b..f3bedf04e056f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java @@ -63,7 +63,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderThrows() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -87,7 +87,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -108,7 +108,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index 9376e4da76261..d9f3ed0c394db 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -114,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -144,7 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -173,7 +173,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java new file mode 100644 index 0000000000000..709a4a0630ba0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/action/LlamaActionCreatorTests.java @@ -0,0 +1,283 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.llama.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.llama.completion.LlamaChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.llama.embeddings.LlamaEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class LlamaActionCreatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "embeddings": [ + [ + -0.0123, + 0.123 + ] + ] + { + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); + + assertEmbeddingsRequest(); + } + } + + public void testExecute_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + [ + { + "embeddings": [ + [ + -0.0123, + 0.123 + ] + ] + { + ] + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createEmbeddingsFuture(sender, createWithEmptySettings(threadPool)); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to parse object: expecting token of type [START_ARRAY] but found [START_OBJECT]") + ); + + assertEmbeddingsRequest(); + } + } + + public void testExecute_ReturnsSuccessfulResponse_ForCompletionAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-03e70a75-efb6-447d-b661-e5ed0bd59ce9", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "message": { + "content": "Hello there, how may I assist you today?", + "refusal": null, + "role": "assistant", + "annotations": null, + "audio": null, + "function_call": null, + "tool_calls": null + } + } + ], + "created": 1750157476, + "model": "llama3.2:3b", + "object": "chat.completion", + "service_tier": null, + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 30, + "total_tokens": 40, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + + assertCompletionRequest(); + } + } + + public void testExecute_FailsFromInvalidResponseFormat_ForCompletionAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createCompletionFuture( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send Llama completion request from inference entity id [id]. Cause: Required [choices]") + ); + + assertCompletionRequest(); + } + } + + private PlainActionFuture createEmbeddingsFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaEmbeddingsModelTests.createEmbeddingsModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute( + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + return listener; + } + + private PlainActionFuture createCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = LlamaChatCompletionModelTests.createCompletionModel("model", getUrl(webServer), "secret"); + var actionCreator = new LlamaActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertCompletionRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + + @SuppressWarnings("unchecked") + private void assertEmbeddingsRequest() throws IOException { + assertCommonRequestProperties(); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("contents"), instanceOf(List.class)); + var inputList = (List) requestMap.get("contents"); + assertThat(inputList, contains("abc")); + } + + private void assertCommonRequestProperties() { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index d7c72cf98e267..102cdbec77d74 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -109,7 +109,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -222,7 +222,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -285,7 +285,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -625,7 +625,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -712,7 +712,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -784,7 +784,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index 17c08dee34e5c..4a1609c7a27df 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -115,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -155,7 +154,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -179,7 +178,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -203,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -221,7 +220,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -239,7 +238,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index 7e89b5c3497c9..2f51ec22c791c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -111,11 +110,7 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index a4310412514cd..ff3f23b2d9027 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -121,11 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,11 +217,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -323,11 +314,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); + action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -396,7 +383,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -420,7 +407,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -438,7 +425,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -461,7 +448,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class );