From e38621bee43991906c4ba7b32458458e5ac97909 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Wed, 19 Feb 2025 15:21:19 -0500 Subject: [PATCH 1/8] Adding validation to ElasticsearchInternalService --- .../inference/InferenceService.java | 4 +- ...ransportDeleteInferenceEndpointAction.java | 4 +- .../TransportPutInferenceModelAction.java | 45 +- .../inference/services/ServiceUtils.java | 51 --- .../AlibabaCloudSearchService.java | 14 - .../amazonbedrock/AmazonBedrockService.java | 14 - .../services/anthropic/AnthropicService.java | 7 - .../azureaistudio/AzureAiStudioService.java | 7 - .../azureopenai/AzureOpenAiService.java | 14 - .../services/cohere/CohereService.java | 14 - .../elastic/ElasticInferenceService.java | 7 - .../BaseElasticsearchInternalService.java | 7 +- .../ElasticsearchInternalService.java | 26 -- .../googleaistudio/GoogleAiStudioService.java | 7 - .../googlevertexai/GoogleVertexAiService.java | 7 - .../huggingface/HuggingFaceService.java | 7 - .../ibmwatsonx/IbmWatsonxService.java | 7 - .../services/jinaai/JinaAIService.java | 13 - .../services/mistral/MistralService.java | 7 - .../services/openai/OpenAiService.java | 14 - .../ChatCompletionModelValidator.java | 5 +- ...icsearchInternalServiceModelValidator.java | 30 ++ .../services/validation/ModelValidator.java | 3 +- .../validation/ModelValidatorBuilder.java | 11 +- .../ServiceIntegrationValidator.java | 3 +- ...CompletionServiceIntegrationValidator.java | 41 +- .../validation/SimpleModelValidator.java | 11 +- .../SimpleServiceIntegrationValidator.java | 6 +- .../TextEmbeddingModelValidator.java | 5 +- .../inference/services/ServiceUtilsTests.java | 109 ----- .../AlibabaCloudSearchServiceTests.java | 69 ---- .../AmazonBedrockServiceTests.java | 227 ---------- .../AzureAiStudioServiceTests.java | 133 ------ .../azureopenai/AzureOpenAiServiceTests.java | 351 ---------------- .../services/cohere/CohereServiceTests.java | 192 --------- .../ElasticsearchInternalServiceTests.java | 61 --- 
.../GoogleAiStudioServiceTests.java | 126 ------ .../huggingface/HuggingFaceServiceTests.java | 84 ---- .../ibmwatsonx/IbmWatsonxServiceTests.java | 163 -------- .../services/jinaai/JinaAIServiceTests.java | 172 -------- .../services/mistral/MistralServiceTests.java | 29 -- .../services/openai/OpenAiServiceTests.java | 389 ------------------ .../ChatCompletionModelValidatorTests.java | 18 +- .../ModelValidatorBuilderTests.java | 4 +- ...etionServiceIntegrationValidatorTests.java | 27 +- .../validation/SimpleModelValidatorTests.java | 16 +- ...impleServiceIntegrationValidatorTests.java | 22 +- .../TextEmbeddingModelValidatorTests.java | 18 +- 48 files changed, 155 insertions(+), 2446 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index e1ebd8bb81ff4..708a48dd39908 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -158,10 +158,10 @@ void chunkedInfer( /** * Stop the model deployment. * The default action does nothing except acknowledge the request (true). 
- * @param unparsedModel The unparsed model configuration + * @param model The model configuration * @param listener The listener */ - default void stop(UnparsedModel unparsedModel, ActionListener listener) { + default void stop(Model model, ActionListener listener) { listener.onResponse(true); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 3b6901ae0c31d..28a400c009af4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -125,7 +125,9 @@ private void doExecuteForked( var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { - service.get().stop(unparsedModel, listener); + var model = service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + service.get().stop(model, listener); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 73af12dacfadf..16b049c26ee5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ServiceUtils; import 
org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.io.IOException; import java.util.List; @@ -192,19 +193,23 @@ private void parseAndStoreModel( ActionListener storeModelListener = listener.delegateFailureAndWrap( (delegate, verifiedModel) -> modelRegistry.storeModel( verifiedModel, - ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> { - if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) { - delegate.onFailure( - new ElasticsearchStatusException( - "One or more nodes in your cluster does not support chunking_settings. " - + "Please update all nodes in your cluster to the latest version to use chunking_settings.", - RestStatus.BAD_REQUEST - ) - ); - } else { - delegate.onFailure(e); + ActionListener.wrap( + r -> listener.onResponse(new PutInferenceModelAction.Response(verifiedModel.getConfigurations())), + e -> { + if (e.getCause() instanceof StrictDynamicMappingException + && e.getCause().getMessage().contains("chunking_settings")) { + delegate.onFailure( + new ElasticsearchStatusException( + "One or more nodes in your cluster does not support chunking_settings. 
" + + "Please update all nodes in your cluster to the latest version to use chunking_settings.", + RestStatus.BAD_REQUEST + ) + ); + } else { + delegate.onFailure(e); + } } - }) + ) ) ); @@ -212,26 +217,14 @@ private void parseAndStoreModel( if (skipValidationAndStart) { storeModelListener.onResponse(model); } else { - service.checkModelConfig(model, storeModelListener); + ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService) + .validate(service, model, timeout, storeModelListener); } }); service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener); } - private void startInferenceEndpoint( - InferenceService service, - TimeValue timeout, - Model model, - ActionListener listener - ) { - if (skipValidationAndStart) { - listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations())); - } else { - service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations()))); - } - } - private Map requestToMap(PutInferenceModelAction.Request request) throws IOException { try ( XContentParser parser = XContentHelper.createParser( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 13d641101a1cf..94fe594d19041 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -8,22 +8,16 @@ package org.elasticsearch.xpack.inference.services; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; 
import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbedding; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; @@ -723,51 +717,6 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod ); } - /** - * Evaluate the model and return the text embedding size - * @param model Should be a text embedding model - * @param service The inference service - * @param listener Size listener - */ - public static void getEmbeddingSize(Model model, InferenceService service, ActionListener listener) { - assert model.getTaskType() == TaskType.TEXT_EMBEDDING; - - service.infer( - model, - null, - List.of(TEST_EMBEDDING_INPUT), - false, - Map.of(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener.delegateFailureAndWrap((delegate, r) -> { - if (r instanceof TextEmbedding embeddingResults) { - try { - delegate.onResponse(embeddingResults.getFirstEmbeddingSize()); - } catch (Exception e) { - delegate.onFailure( - new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e) - ); - } - } else { - delegate.onFailure( - new ElasticsearchStatusException( - "Could not determine embedding size. 
" - + "Expected a result of type [" - + InferenceTextEmbeddingFloatResults.NAME - + "] got [" - + r.getWriteableName() - + "]", - RestStatus.BAD_REQUEST - ) - ); - } - }) - ); - } - - private static final String TEST_EMBEDDING_INPUT = "how big"; - public static SecureString apiKey(@Nullable ApiKeySecrets secrets) { // To avoid a possible null pointer throughout the code we'll create a noop api key of an empty array return secrets == null ? new SecureString(new char[0]) : secrets.apiKey(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 589ca1e033f06..dee26b4633df4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -326,19 +325,6 @@ private EmbeddingRequestChunker.EmbeddingType getEmbeddingTypeFromTaskType(TaskT }; } - /** - * For text embedding models get the embedding size and - * update the service settings. 
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 493acd3c0cd1a..79d94b60c93ce 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -45,7 +45,6 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.io.IOException; import java.util.EnumSet; @@ -302,19 +301,6 @@ public Set supportedStreamingTasks() { return COMPLETION_ONLY; } - /** - * For text embedding models get the embedding size and - * update the service settings. 
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 64fe42fbbc171..59f2e0b637b2a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -36,7 +36,6 @@ import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -176,12 +175,6 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index 34a5c2b4cc1e9..f0dd566b52096 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -45,7 +45,6 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -321,12 +320,6 @@ private AzureAiStudioModel createModelFromPersistent( ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 9a77b63337978..8964d7bb8668a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -294,19 +293,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. - * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 6c2d3bb96d74d..0955b9e591f47 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -44,7 +44,6 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import 
org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -296,19 +295,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. - * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof CohereEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 737c549255a71..b0ca9db73b058 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -54,7 +54,6 @@ import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.ArrayList; @@ -556,12 +555,6 @@ private ElasticInferenceServiceModel createModelFromPersistent( ); } - @Override - public void checkModelConfig(Model model, ActionListener 
listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { var inputsAsList = DocumentsOnlyInput.of(inputs).getInputs(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index f743b94df3810..3cbdefbbe24ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -22,7 +22,6 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -118,9 +117,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL } @Override - public void stop(UnparsedModel unparsedModel, ActionListener listener) { - - var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + public void stop(Model model, ActionListener listener) { if (model instanceof ElasticsearchInternalModel esModel) { var serviceSettings = esModel.getServiceSettings(); @@ -297,7 +294,7 @@ protected void maybeStartDeployment( 
InferModelAction.Request request, ActionListener listener ) { - if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ddc5e3e1aa36c..a9840566c0ea4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -56,7 +56,6 @@ import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.util.ArrayList; import java.util.Collections; @@ -536,31 +535,6 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M } } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) { - // At this point the inference endpoint configuration has not been persisted yet, if we attempt to do inference using the - // inference id we'll get an error because the trained model code needs to use the persisted inference endpoint to retrieve the - // model id. 
To get around this we'll have the getEmbeddingSize() method use the model id instead of inference id. So we need - // to create a temporary model that overrides the inference id with the model id. - var temporaryModelWithModelId = new CustomElandEmbeddingModel( - elandModel.getServiceSettings().modelId(), - elandModel.getTaskType(), - elandModel.getConfigurations().getService(), - elandModel.getServiceSettings(), - elandModel.getConfigurations().getChunkingSettings() - ); - - ServiceUtils.getEmbeddingSize( - temporaryModelWithModelId, - this, - listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(elandModel, size))) - ); - } else { - listener.onResponse(model); - } - } - private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) { CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings( model.getServiceSettings().getNumAllocations(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 205cc545a23f0..c24532dd0187f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import 
java.util.HashMap; @@ -244,12 +243,6 @@ public Set supportedStreamingTasks() { return COMPLETION_ONLY; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 3e921f669e864..a99655fc8c3c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -41,7 +41,6 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -178,12 +177,6 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override protected void doInfer( 
Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 73c1446b9bb26..ef8a5ee0c9969 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -83,12 +82,6 @@ protected HuggingFaceModel createModel( }; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof HuggingFaceEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 3fa423c2dae19..3be9981bc5c91 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -236,12 +235,6 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 37add1e264704..0be6dc0a53327 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -43,7 +43,6 @@ import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; 
import java.util.HashMap; @@ -278,18 +277,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. - * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof JinaAIEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 3e40575e42faf..cd0aea84d5c1f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -280,12 +279,6 @@ private MistralEmbeddingsModel createModelFromPersistent( ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof 
MistralEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 94312a39882fd..35d59f36109ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -332,19 +331,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. 
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof OpenAiEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java index b7a9fa7e6f3ab..624f223c9f3e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; @@ -20,8 +21,8 @@ public ChatCompletionModelValidator(ServiceIntegrationValidator serviceIntegrati } @Override - public void validate(InferenceService service, Model model, ActionListener listener) { - serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + serviceIntegrationValidator.validate(service, model, timeout, listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(postValidate(service, model)); })); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
new file mode 100644
index 0000000000000..3c4abafb401fb
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.validation;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.Model;
+
+/**
+ * Decorates another {@link ModelValidator} so that a failed validation also stops the model
+ * through {@link InferenceService} before the failure is reported to the caller.
+ */
+public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
+
+    private final ModelValidator modelValidator;
+
+    public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) {
+        this.modelValidator = modelValidator;
+    }
+
+    /**
+     * Runs the wrapped validator. A successful result is passed through unchanged. On failure the
+     * model is stopped and the original validation exception is then reported to {@code listener},
+     * whether or not the stop call itself succeeded.
+     */
+    @Override
+    public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<Model> listener) {
+        modelValidator.validate(service, model, timeout, listener.delegateResponse((delegate, validationException) -> {
+            service.stop(model, ActionListener.wrap(stopped -> delegate.onFailure(validationException), stopException -> {
+                // The validation failure stays the primary error; keep the cleanup failure attached instead of dropping it.
+                if (stopException != validationException) {
+                    validationException.addSuppressed(stopException);
+                }
+                delegate.onFailure(validationException);
+            }));
+        }));
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java
index c435939a17568..2c7701942ecb0 100644
---
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java @@ -8,9 +8,10 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; public interface ModelValidator { - void validate(InferenceService service, Model model, ActionListener listener); + void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index 1c4306c4edd46..c0088858588ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -11,7 +11,16 @@ import org.elasticsearch.inference.TaskType; public class ModelValidatorBuilder { - public static ModelValidator buildModelValidator(TaskType taskType) { + public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) { + var modelValidator = buildModelValidatorForTaskType(taskType); + if (isElasticsearchInternalService) { + return new ElasticsearchInternalServiceModelValidator(modelValidator); + } else { + return modelValidator; + } + } + + private static ModelValidator buildModelValidatorForTaskType(TaskType taskType) { if (taskType == null) { throw new IllegalArgumentException("Task type can't be null"); } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java index 09fb43f584cf0..49ade6c00fb22 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java @@ -8,10 +8,11 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; public interface ServiceIntegrationValidator { - void validate(InferenceService service, Model model, ActionListener listener); + void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java index 1092d84a6ef6b..872dc84f06b98 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java @@ -10,11 +10,11 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import 
org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import java.util.List; @@ -28,32 +28,27 @@ public class SimpleChatCompletionServiceIntegrationValidator implements ServiceI private static final List TEST_INPUT = List.of("how big"); @Override - public void validate(InferenceService service, Model model, ActionListener listener) { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { var chatCompletionInput = new UnifiedChatInput(TEST_INPUT, USER_ROLE, false); - service.unifiedCompletionInfer( - model, - chatCompletionInput.getRequest(), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> { - if (r != null) { - listener.onResponse(r); - } else { - listener.onFailure( - new ElasticsearchStatusException( - "Could not complete inference endpoint creation as validation call to service returned null response.", - RestStatus.BAD_REQUEST - ) - ); - } - }, e -> { + service.unifiedCompletionInfer(model, chatCompletionInput.getRequest(), timeout, ActionListener.wrap(r -> { + if (r != null) { + listener.onResponse(r); + } else { listener.onFailure( new ElasticsearchStatusException( - "Could not complete inference endpoint creation as validation call to service threw an exception.", - RestStatus.BAD_REQUEST, - e + "Could not complete inference endpoint creation as validation call to service returned null response.", + RestStatus.BAD_REQUEST ) ); - }) - ); + } + }, e -> { + listener.onFailure( + new ElasticsearchStatusException( + "Could not complete inference endpoint creation as validation call to service threw an exception.", + RestStatus.BAD_REQUEST, + e + ) + ); + })); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java index f44cf61079369..3d592840f533b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; @@ -20,11 +21,9 @@ public SimpleModelValidator(ServiceIntegrationValidator serviceIntegrationValida } @Override - public void validate(InferenceService service, Model model, ActionListener listener) { - serviceIntegrationValidator.validate( - service, - model, - listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(model); }) - ); + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + serviceIntegrationValidator.validate(service, model, timeout, listener.delegateFailureAndWrap((delegate, r) -> { + delegate.onResponse(model); + })); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 70f01e77b9369..3905e6cb78c03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -10,13 +10,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.util.List; import java.util.Map; @@ -26,7 +26,7 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali private static final String QUERY = "test query"; @Override - public void validate(InferenceService service, Model model, ActionListener listener) { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { service.infer( model, model.getTaskType().equals(TaskType.RERANK) ? 
QUERY : null, @@ -34,7 +34,7 @@ public void validate(InferenceService service, Model model, ActionListener { if (r != null) { listener.onResponse(r); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index 1fe5c684196fe..6452a859bf458 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -26,8 +27,8 @@ public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegratio } @Override - public void validate(InferenceService service, Model model, ActionListener listener) { - serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + serviceIntegrationValidator.validate(service, model, timeout, listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(postValidate(service, model, r)); })); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index e3df0f0b5a2e1..30e6a86bc1d96 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -8,22 +8,12 @@ package org.elasticsearch.xpack.inference.services; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.results.InferenceTextEmbeddingByteResultsTests; -import org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests; import java.util.EnumSet; import java.util.HashMap; @@ -41,23 +31,14 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.getEmbeddingSize; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class ServiceUtilsTests extends ESTestCase { - - private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); - public void testRemoveAsTypeWithTheCorrectType() { Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE, "d", 1.0)); @@ -903,96 +884,6 @@ public void testExtractRequiredEnum_HasValidationErrorOnMissingSetting() { assertThat(validationException.validationErrors().get(0), is("[testscope] does not contain the required setting [missing_key]")); } - public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); - listener.onResponse(new InferenceTextEmbeddingFloatResults(List.of())); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("Could not determine embedding size")); - assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty")); - } - - public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmpty() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - doAnswer(invocation -> { - ActionListener listener = 
invocation.getArgument(7); - listener.onResponse(new InferenceTextEmbeddingByteResults(List.of())); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("Could not determine embedding size")); - assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty")); - } - - public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - var textEmbedding = TextEmbeddingResultsTests.createRandomResults(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); - listener.onResponse(textEmbedding); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var size = listener.actionGet(TIMEOUT); - - assertThat(size, is(textEmbedding.embeddings().get(0).getSize())); - } - - public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - var textEmbedding = InferenceTextEmbeddingByteResultsTests.createRandomResults(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(7); - listener.onResponse(textEmbedding); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, 
service, listener); - - var size = listener.actionGet(TIMEOUT); - - assertThat(size, is(textEmbedding.embeddings().get(0).getSize())); - } - private static Map modifiableMap(Map aMap) { return new HashMap<>(aMap); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index 1ca50d1887ee1..2a8727355448e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -18,7 +18,6 @@ import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; -import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -38,7 +37,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; @@ -52,7 +50,6 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; import 
org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; -import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; @@ -261,72 +258,6 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testCheckModelConfig() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool)) { - @Override - public void doInfer( - Model model, - InferenceInputs inputs, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener listener - ) { - InferenceTextEmbeddingFloatResults results = new InferenceTextEmbeddingFloatResults( - List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { -0.028680f, 0.022033f })) - ); - - listener.onResponse(results); - } - }) { - Map serviceSettingsMap = new HashMap<>(); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); - serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); - - Map taskSettingsMap = new HashMap<>(); - - Map secretSettingsMap = new HashMap<>(); - secretSettingsMap.put("api_key", "secret"); - - var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( - "service", - TaskType.TEXT_EMBEDDING, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - Map expectedServiceSettingsMap = new HashMap<>(); - 
expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); - expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); - expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); - expectedServiceSettingsMap.put(ServiceFields.SIMILARITY, "DOT_PRODUCT"); - expectedServiceSettingsMap.put(ServiceFields.DIMENSIONS, 2); - - Map expectedTaskSettingsMap = new HashMap<>(); - - Map expectedSecretSettingsMap = new HashMap<>(); - expectedSecretSettingsMap.put("api_key", "secret"); - - var expectedModel = AlibabaCloudSearchEmbeddingsModelTests.createModel( - "service", - TaskType.TEXT_EMBEDDING, - expectedServiceSettingsMap, - expectedTaskSettingsMap, - expectedSecretSettingsMap - ); - - MatcherAssert.assertThat(result, is(expectedModel)); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 6505c280c295a..2a373c8220448 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -51,7 +51,6 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -1039,232 +1038,6 @@ public void 
testInfer_SendsRequest_ForChatCompletionModel() throws IOException { } } - public void testCheckModelConfig_IncludesMaxTokens_ForEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new InferenceTextEmbeddingFloatResults( - List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - null, - false, - 100, - null, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 2, - false, - 100, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ) - ) - ); - var inputStrings = requestSender.getInputs(); - - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - - public void testCheckModelConfig_HasSimilarity_ForEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, 
Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new InferenceTextEmbeddingFloatResults( - List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - null, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 2, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ) - ) - ); - var inputStrings = requestSender.getInputs(); - - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - - public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new InferenceTextEmbeddingFloatResults( - List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new 
float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 3, - true, - null, - null, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - var inputStrings = requestSender.getInputs(); - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new InferenceTextEmbeddingFloatResults( - List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 100, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - 
service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 2, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ) - ) - ); - var inputStrings = requestSender.getInputs(); - - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index cebea7901b956..80bf0b804c715 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -54,7 +54,6 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -842,138 +841,6 @@ public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() } } - public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); - - var model = AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - false, - null, - null, - null, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - 2, - false, - null, - SimilarityMeasure.DOT_PRODUCT, - null, - null - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big")))); - } - } - - public void testCheckModelConfig_ForEmbeddingsModel_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); - - var model = AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - 3, - true, - null, - null, - null, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. 
" - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "dimensions", 3))); - } - } - - public void testCheckModelConfig_WorksForChatCompletionsModel() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); - - var model = AzureAiStudioChatCompletionModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - null, - null, - null, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureAiStudioChatCompletionModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - null, - null, - AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS, - null - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index e67a5dac0e7c2..d1c38409b7dfc 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -844,356 +843,6 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { } } - public void testCheckModelConfig_IncludesMaxTokens() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - null, - false, - 100, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - 100, - SimilarityMeasure.DOT_PRODUCT, - "user", - "apikey", - null, - "id" - ) - ) - ); - - 
assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); - } - } - - public void testCheckModelConfig_HasSimilarity() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - null, - false, - null, - SimilarityMeasure.COSINE, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - null, - SimilarityMeasure.COSINE, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); - } - } - - public void testCheckModelConfig_AddsDefaultSimilarityDotProduct() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, 
createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - null, - false, - null, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - null, - SimilarityMeasure.DOT_PRODUCT, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); - } - } - - public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - 
"deployment", - "apiversion", - 3, - true, - 100, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user", "dimensions", 3))); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException, - URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 100, - false, - 100, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - 
"deployment", - "apiversion", - 2, - false, - 100, - SimilarityMeasure.DOT_PRODUCT, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 90e5dc6890c45..50f7fbf081a9c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -882,198 +882,6 @@ public void testInfer_SendsRequest() throws IOException { } } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ - 0.123, - -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - null, - null - 
); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - null, - null - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ - 0.123, - -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - null, - null - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - null, - null, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ - 0.123, - -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - null, - null, - SimilarityMeasure.COSINE - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - null, - null, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index d1ce79b863c61..cd805b9594f6a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -1451,67 +1451,6 @@ public void onFailure(Exception e) { assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); } - public void testParseRequestConfigEland_SetsDimensionsToOne() { - var client = mock(Client.class); - doAnswer(invocationOnMock -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocationOnMock - .getArguments()[2]; - listener.onResponse( - new InferModelAction.Response(List.of(new MlTextEmbeddingResults("field", new double[] { 0.1 }, false)), "id", true) - ); - - var request = (InferModelAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getId(), is("custom-model")); - return Void.TYPE; - }).when(client).execute(eq(InferModelAction.INSTANCE), any(), any()); - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings( - 1, - 4, - "custom-model", - null, - null, - 1, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ); - var taskType = TaskType.TEXT_EMBEDDING; - var expectedModel = new CustomElandEmbeddingModel( - randomInferenceEntityId, - taskType, - ElasticsearchInternalService.NAME, - serviceSettings, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig( - new CustomElandEmbeddingModel( - randomInferenceEntityId, - taskType, - ElasticsearchInternalService.NAME, - new CustomElandInternalTextEmbeddingServiceSettings( - 1, - 4, - "custom-model", - null, - null, - null, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ), - null - ), - listener - ); - var model = listener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(model, is(expectedModel)); - } - public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { { 
assertFalse( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index d0760a583df29..9be097f49ffab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -955,132 +955,6 @@ public void testInfer_ResourceNotFound() throws IOException { } } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - var modelId = "model"; - var apiKey = "apiKey"; - - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "embeddings": [ - { - "values": [ - 0.0123, - -0.0123 - ] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, 1, similarityMeasure); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - // Updates dimensions to two as two embeddings were returned instead of one as specified before - assertThat( - result, - is(GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, 2, similarityMeasure)) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var oneDimension = 1; - 
var modelId = "model"; - var apiKey = "apiKey"; - - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "embeddings": [ - { - "values": [ - 0.0123 - ] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, oneDimension, null); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - GoogleAiStudioEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - apiKey, - oneDimension, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var oneDimension = 1; - var modelId = "model"; - var apiKey = "apiKey"; - - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "embeddings": [ - { - "values": [ - 0.0123 - ] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = GoogleAiStudioEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - apiKey, - oneDimension, - SimilarityMeasure.COSINE - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - GoogleAiStudioEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - apiKey, - oneDimension, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index f3137d7011cec..50f7408e63a6b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -630,90 +630,6 @@ public void testInfer_SendsElserRequest() throws IOException { } } - public void testCheckModelConfig_IncludesMaxTokens() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "embeddings": [ - [ - -0.0123 - ] - ] - { - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT)) - ); - } - } - - public void testCheckModelConfig_UsesUserSpecifiedSimilarity() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { 
- "embeddings": [ - [ - -0.0123 - ] - ] - { - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 2, SimilarityMeasure.COSINE); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE)) - ); - } - } - - public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "embeddings": [ - [ - -0.0123 - ] - ] - { - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 2, null); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE)) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 99b7b3868b7f4..c0eb93de8062e 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -795,169 +795,6 @@ public void testInfer_ResourceNotFound() throws IOException { } } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - - try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 1, - similarityMeasure - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - // Updates dimensions to two as two embeddings were returned instead of one as specified before - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 2, - similarityMeasure - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 2aeb0447f9c78..801861e3ed0fd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -800,178 +800,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException verifyNoMoreInteractions(sender); } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "jina-clip-v2" - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "jina-clip-v2" - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "jina-clip-v2", - null - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "jina-clip-v2", - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "jina-clip-v2", - SimilarityMeasure.COSINE 
- ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "jina-clip-v2", - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { testUpdateModelWithEmbeddingDetails_Successful(null); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 95ac2cde0e31b..e82afc78e9d88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -501,34 +500,6 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC } } - public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); 
- - var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - model.setURI(getUrl(webServer)); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", 2, null, SimilarityMeasure.DOT_PRODUCT, null)) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "encoding_format", "float", "model", "mistral-embed")) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index b31cdf4f9d592..791f0e02c9443 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -51,7 +51,6 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -61,7 +60,6 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; -import java.util.Map; 
import java.util.concurrent.TimeUnit; import static org.elasticsearch.ExceptionsHelper.unwrapCause; @@ -1232,393 +1230,6 @@ public void testSupportsStreaming() throws IOException { } } - public void testCheckModelConfig_IncludesMaxTokens() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", 100); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat(result, is(OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", 100, 2))); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 
- } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 3, true); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user", "dimensions", 3)) - ); - } - } - - public void testCheckModelConfig_ReturnsModelWithDimensionsSetTo2_AndDocProductSet_IfDimensionsSetByUser_ButSetToNull() - throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, null, true); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - 
getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - true - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // since dimensions were null they should not be sent in the request - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ReturnsModelWithSameDimensions_AndDocProductSet_IfDimensionsSetByUser_AndTheyMatchReturnedSize() - throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 2, true); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - true - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user", "dimensions", 2)) - ); - } - } - - public void 
testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 100, false); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - false - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_SetsSimilarityToDocProduct_WhenNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 100, false); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - false - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_DoesNotOverrideSimilarity_WhenNotNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.COSINE, - 100, - 100, - false - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.COSINE, - 
100, - 2, - false - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { try (var service = createOpenAiService()) { var model = createCompletionModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java index 89ab07d25e83d..bd52bdd52ab3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -26,6 +27,9 @@ import static org.mockito.MockitoAnnotations.openMocks; public class ChatCompletionModelValidatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + @Mock private ServiceIntegrationValidator mockServiceIntegrationValidator; @Mock @@ -48,14 +52,14 @@ public void setup() { public void testValidate_ServiceIntegrationValidatorThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator) - .validate(eq(mockInferenceService), eq(mockModel), any()); + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); 
assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verifyNoMoreInteractions( mockServiceIntegrationValidator, @@ -70,14 +74,14 @@ public void testValidate_ChatCompletionDetailsUpdated() { when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); when(mockInferenceService.updateModelWithChatCompletionDetails(mockModel)).thenReturn(mockModel); doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); + ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(mockInferenceServiceResults); return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verify(mockActionListener).onResponse(mockModel); verify(mockInferenceService).updateModelWithChatCompletionDetails(mockModel); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index a854bbdec507a..19ea0bedaaea5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -16,12 +16,12 @@ public class ModelValidatorBuilderTests extends ESTestCase { public void testBuildModelValidator_NullTaskType() { - assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null); }); + assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, false); }); } public void testBuildModelValidator_ValidTaskType() { taskTypeToModelValidatorClassMap().forEach((taskType, modelValidatorClass) -> { - assertThat(ModelValidatorBuilder.buildModelValidator(taskType), isA(modelValidatorClass)); + assertThat(ModelValidatorBuilder.buildModelValidator(taskType, false), isA(modelValidatorClass)); }); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java index f02c4662d49e4..1103a26c45b2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java @@ -9,13 +9,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import 
org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -44,6 +44,7 @@ public class SimpleChatCompletionServiceIntegrationValidatorTests extends ESTest null, null ); + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @Mock private InferenceService mockInferenceService; @@ -67,9 +68,12 @@ public void setup() { public void testValidate_ServiceThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockInferenceService) - .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); - assertThrows(ElasticsearchStatusException.class, () -> underTest.validate(mockInferenceService, mockModel, mockActionListener)); + assertThrows( + ElasticsearchStatusException.class, + () -> underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener) + ); verifyCallToService(); } @@ -117,10 +121,9 @@ private void mockSuccessfulCallToService(InferenceServiceResults result) { ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(result); return null; - }).when(mockInferenceService) - .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + }).when(mockInferenceService).unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void 
mockNullResponseFromService() { @@ -132,19 +135,13 @@ private void mockFailureResponseFromService(Exception exception) { ActionListener responseListener = ans.getArgument(3); responseListener.onFailure(exception); return null; - }).when(mockInferenceService) - .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + }).when(mockInferenceService).unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyCallToService() { - verify(mockInferenceService).unifiedCompletionInfer( - eq(mockModel), - eq(EXPECTED_REQUEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), - any() - ); + verify(mockInferenceService).unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java index b14a1f8f3cc77..c1679fd7c68da 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -27,6 +28,9 @@ import static 
org.mockito.MockitoAnnotations.openMocks; public class SimpleModelValidatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + @Mock private ServiceIntegrationValidator mockServiceIntegrationValidator; @Mock @@ -49,11 +53,11 @@ public void setup() { public void testValidate_ServiceIntegrationValidatorThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator) - .validate(eq(mockInferenceService), eq(mockModel), any()); + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); verifyInteractions(); } @@ -66,16 +70,16 @@ public void testValidate_ServiceReturnsInferenceServiceResults() { private void mockCallToServiceIntegrationValidator(InferenceServiceResults results) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); + ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(results); return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyInteractions() { - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener); } diff 
--git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java index 22ef35c3a46d3..2d91f4a185c5f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -9,13 +9,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.junit.Before; import org.mockito.Mock; @@ -35,6 +35,7 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase { private static final List TEST_INPUT = List.of("how big"); private static final String TEST_QUERY = "test query"; + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @Mock private InferenceService mockInferenceService; @@ -67,13 +68,13 @@ public void testValidate_ServiceThrowsException() { eq(false), eq(Map.of()), eq(InputType.INGEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), + eq(TIMEOUT), any() ); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); verifyCallToService(false); @@ -101,18 +102,9 @@ private void 
mockSuccessfulCallToService(String query, InferenceServiceResults r responseListener.onResponse(result); return null; }).when(mockInferenceService) - .infer( - eq(mockModel), - eq(query), - eq(TEST_INPUT), - eq(false), - eq(Map.of()), - eq(InputType.INGEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), - any() - ); + .infer(eq(mockModel), eq(query), eq(TEST_INPUT), eq(false), eq(Map.of()), eq(InputType.INGEST), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyCallToService(boolean withQuery) { @@ -124,7 +116,7 @@ private void verifyCallToService(boolean withQuery) { eq(false), eq(Map.of()), eq(InputType.INGEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), + eq(TIMEOUT), any() ); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index d608b42841305..22f19b6ce5770 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -37,6 +38,9 @@ import static org.mockito.MockitoAnnotations.openMocks; public class TextEmbeddingModelValidatorTests extends ESTestCase { 
+ + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + @Mock private ServiceIntegrationValidator mockServiceIntegrationValidator; @Mock @@ -64,14 +68,14 @@ public void setup() { public void testValidate_ServiceIntegrationValidatorThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator) - .validate(eq(mockInferenceService), eq(mockModel), any()); + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener, mockServiceSettings); } @@ -143,14 +147,14 @@ private void mockSuccessfulValidation(Boolean dimensionsSetByUser) { private void mockCallToServiceIntegrationValidator(InferenceServiceResults results) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); + ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(results); return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + 
verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); } } From b3a913f7940ced8a7fcf74bbd774b6cfa29b59e3 Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Thu, 20 Feb 2025 11:16:39 -0500 Subject: [PATCH 2/8] Update docs/changelog/123044.yaml --- docs/changelog/123044.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/123044.yaml diff --git a/docs/changelog/123044.yaml b/docs/changelog/123044.yaml new file mode 100644 index 0000000000000..2cb758c23edec --- /dev/null +++ b/docs/changelog/123044.yaml @@ -0,0 +1,5 @@ +pr: 123044 +summary: Adding validation to `ElasticsearchInternalService` +area: Machine Learning +type: enhancement +issues: [] From 6108381af87323f14344496eb66aaa1626412854 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 7 Mar 2025 18:55:48 +0000 Subject: [PATCH 3/8] [CI] Auto commit changes from spotless --- .../xpack/inference/services/openai/OpenAiServiceTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 2390ddd7c17a9..16e248dcb6197 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -63,7 +63,6 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; From eb21ca817b8d233652a38185848d3e4c277813d6 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Fri, 7 Mar 2025 14:37:14 -0500 Subject: [PATCH 4/8] Removing checkModelConfig --- 
.../inference/InferenceService.java | 11 -- .../services/voyageai/VoyageAIService.java | 13 -- .../elastic/ElasticInferenceServiceTests.java | 26 --- .../voyageai/VoyageAIServiceTests.java | 169 ------------------ 4 files changed, 219 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 968d61e511d42..7af4ec971333c 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -165,17 +165,6 @@ default void stop(Model model, ActionListener listener) { listener.onResponse(true); } - /** - * Optionally test the new model configuration in the inference service. - * This function should be called when the model is first created, the - * default action is to do nothing. - * @param model The new model - * @param listener The listener - */ - default void checkModelConfig(Model model, ActionListener listener) { - listener.onResponse(model); - }; - /** * Update a text embedding model's dimensions based on a provided embedding * size and set the default similarity if required. The default behaviour is to just return the model. 
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 16659f075c564..ec43ef6b5edca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -40,7 +40,6 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.voyageai.rerank.VoyageAIRerankModel; @@ -304,18 +303,6 @@ private static int getBatchSize(VoyageAIModel model) { return MODEL_BATCH_SIZES.getOrDefault(model.getServiceSettings().modelId(), DEFAULT_BATCH_SIZE); } - /** - * For text embedding models get the embedding size and - * update the service settings. 
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof VoyageAIEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index de3dac3577d44..6256b65b01222 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -316,32 +316,6 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = createService(senderFactory, getUrl(webServer))) { - String responseJson = """ - { - "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat(returnedModel, is(ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"))); - } - } - public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() 
throws IOException { var sender = mock(Sender.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 3a5fce350046e..3987cc43e94a6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -793,175 +793,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotVoyageAIModel() throws IOExceptio verifyNoMoreInteractions(sender); } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "voyage-3-large", - "object": "list", - "usage": { - "total_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "voyage-3-large" - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "voyage-3-large" - ) - ) - ); - } - } - - public void 
testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "voyage-3-large", - "object": "list", - "usage": { - "total_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "voyage-3-large", - (SimilarityMeasure) null - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "voyage-3-large", - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "voyage-3-large", - "object": "list", - "usage": { - "total_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - 
"secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "voyage-3-large", - SimilarityMeasure.COSINE - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "voyage-3-large", - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { testUpdateModelWithEmbeddingDetails_Successful(null); } From 4413a76a0ddfce4b66a5b4b9c0e8f39c7499096b Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Mon, 10 Mar 2025 14:26:56 -0400 Subject: [PATCH 5/8] Fixing IT --- .../TestSparseInferenceServiceExtension.java | 20 +++++++++-- ...stStreamingCompletionServiceExtension.java | 36 +++++++++++++++++-- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 4e10ce45efeac..a295288a21d75 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -34,6 +34,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import 
org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; @@ -61,7 +62,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett public static class TestInferenceService extends AbstractTestInferenceService { public static final String NAME = "test_service"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING); public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} @@ -110,7 +111,8 @@ public void infer( ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); + case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input)); + case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -151,7 +153,7 @@ public void chunkedInfer( } } - private SparseEmbeddingResults makeResults(List input) { + private SparseEmbeddingResults makeSparseEmbeddingResults(List input) { var embeddings = new ArrayList(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); @@ -163,6 +165,18 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } + private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { + var embeddings = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + var values = new float[5]; + for (int j = 0; j < 5; j++) { + values[j] = random.nextFloat(); + } + embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); + } + return new TextEmbeddingFloatResults(embeddings); + 
} + private List makeChunkedResults(List input) { List results = new ArrayList<>(); for (int i = 0; i < input.size(); i++) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 8c876e9947bba..51b5387ca7aa1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -34,8 +34,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; +import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; import java.util.Iterator; @@ -57,7 +59,11 @@ public static class TestInferenceService extends AbstractTestInferenceService { private static final String NAME = "streaming_completion_test_service"; private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, + TaskType.SPARSE_EMBEDDING + ); public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} @@ -111,7 +117,19 @@ public void infer( ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case 
COMPLETION -> listener.onResponse(makeResults(input)); + case COMPLETION -> listener.onResponse(makeChatCompletionResults(input)); + case SPARSE_EMBEDDING -> { + if (stream) { + listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } else { + listener.onResponse(makeTextEmbeddingResults(input)); + } + } default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -139,7 +157,7 @@ public void unifiedCompletionInfer( } } - private StreamingChatCompletionResults makeResults(List input) { + private StreamingChatCompletionResults makeChatCompletionResults(List input) { var responseIter = input.stream().map(s -> s.toUpperCase(Locale.ROOT)).iterator(); return new StreamingChatCompletionResults(subscriber -> { subscriber.onSubscribe(new Flow.Subscription() { @@ -158,6 +176,18 @@ public void cancel() {} }); } + private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { + var embeddings = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + var values = new float[5]; + for (int j = 0; j < 5; j++) { + values[j] = random.nextFloat(); + } + embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); + } + return new TextEmbeddingFloatResults(embeddings); + } + private InferenceServiceResults.Result completionChunk(String delta) { return new InferenceServiceResults.Result() { @Override From 1033c368b9ab0dcaef6549f8449afd57b0688ee1 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 13 Mar 2025 14:08:59 +0000 Subject: [PATCH 6/8] [CI] Auto commit changes from spotless --- .../elasticsearch/xpack/inference/services/ServiceUtils.java | 2 +- .../xpack/inference/services/ServiceUtilsTests.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 49478c4592ed7..94fe594d19041 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -716,7 +716,7 @@ public static ElasticsearchStatusException createInvalidModelException(Model mod RestStatus.INTERNAL_SERVER_ERROR ); } - + public static SecureString apiKey(@Nullable ApiKeySecrets secrets) { // To avoid a possible null pointer throughout the code we'll create a noop api key of an empty array return secrets == null ? new SecureString(new char[0]) : secrets.apiKey(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 2b996cd78ad28..30e6a86bc1d96 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -883,7 +883,7 @@ public void testExtractRequiredEnum_HasValidationErrorOnMissingSetting() { assertThat(validationException.validationErrors().size(), is(1)); assertThat(validationException.validationErrors().get(0), is("[testscope] does not contain the required setting [missing_key]")); } - + private static Map modifiableMap(Map aMap) { return new HashMap<>(aMap); } From 7766b8e439024286284c422a949c2aa301dafd6b Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Thu, 13 Mar 2025 10:27:02 -0400 Subject: [PATCH 7/8] Remove DeepSeek checkModelConfig and fix tests --- .../xpack/inference/InferenceGetServicesIT.java | 14 +++++++++++--- 
.../services/deepseek/DeepSeekService.java | 7 ------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 6f9a550481049..a5c34dc2f14d5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(15)); + assertThat(services.size(), equalTo(16)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -86,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "jinaai", "mistral", "openai", + "test_service", "text_embedding_test_service", "voyageai", "watsonxai" @@ -157,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { List services = getServices(TaskType.SPARSE_EMBEDDING); - assertThat(services.size(), equalTo(5)); + assertThat(services.size(), equalTo(6)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -166,7 +167,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { } assertArrayEquals( - List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", 
"test_service").toArray(), + List.of( + "alibabacloud-ai-search", + "elastic", + "elasticsearch", + "hugging_face", + "streaming_completion_test_service", + "test_service" + ).toArray(), providers ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 6338cee473cbd..259bb93fb05d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -32,7 +32,6 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -173,12 +172,6 @@ public Set supportedStreamingTasks() { return SUPPORTED_TASK_TYPES_FOR_STREAMING; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - private static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); From 059114f4eafeaf54cc40fc3c87b7fc6bcb1f8d66 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Fri, 14 Mar 2025 12:35:00 -0400 Subject: [PATCH 8/8] Cleaning up comments, updating validation input type, and moving model deployment starting to model validator --- ...stStreamingCompletionServiceExtension.java | 2 + .../TransportPutInferenceModelAction.java | 3 +- 
.../inference/services/ServiceUtils.java | 1 + .../BaseElasticsearchInternalService.java | 2 +- .../ElasticsearchInternalService.java | 1 + ...icsearchInternalServiceModelValidator.java | 37 +++- .../SimpleServiceIntegrationValidator.java | 2 +- .../AlibabaCloudSearchServiceTests.java | 1 + ...rchInternalServiceModelValidatorTests.java | 197 ++++++++++++++++++ ...impleServiceIntegrationValidatorTests.java | 17 +- 10 files changed, 253 insertions(+), 10 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 7604e06dbe547..e34018c5b8df1 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -131,6 +131,8 @@ public void infer( ) ); } else { + // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to + // pass. This is required to test that streaming fails for a sparse_embedding endpoint. 
listener.onResponse(makeTextEmbeddingResults(input)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 9dd713446f6e5..4357fa619954c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -207,7 +207,8 @@ private void parseAndStoreModel( delegate.onFailure(e); } } - ) + ), + timeout ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 1847312cf5f61..bdcadb2277c2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 274b6e4e08410..84259d1c0be66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -298,7 +298,7 @@ protected void maybeStartDeployment( InferModelAction.Request request, ActionListener listener ) { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index bf119fd43300a..5e409e6ba5eb4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -58,6 +58,7 @@ import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.util.ArrayList; import java.util.Collections; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java index 3c4abafb401fb..aa3bb0ef9bcb0 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
@@ -7,10 +7,12 @@
 
 package org.elasticsearch.xpack.inference.services.validation;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.rest.RestStatus;
 
 public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
 
@@ -22,9 +24,48 @@ public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator)
 
+    /**
+     * Starts the model deployment and, once the service reports it as started, delegates to the wrapped validator.
+     * If that validator fails or throws, the deployment is stopped before the failure is reported to the listener.
+     */
     @Override
     public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<Model> listener) {
-        modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> {
-            // TODO: Cleanup the below code
-            service.stop(model, ActionListener.wrap((v) -> listener.onFailure(exception), (e) -> listener.onFailure(exception)));
-        }));
+        service.start(model, timeout, ActionListener.wrap((modelDeploymentStarted) -> {
+            if (modelDeploymentStarted) {
+                try {
+                    modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> {
+                        stopModelDeployment(service, model, l, exception);
+                    }));
+                } catch (Exception e) {
+                    // The wrapped validator threw instead of notifying its listener: still undo the deployment started above.
+                    stopModelDeployment(service, model, listener, e);
+                }
+            } else {
+                listener.onFailure(
+                    new ElasticsearchStatusException("Could not deploy model for inference endpoint", RestStatus.INTERNAL_SERVER_ERROR)
+                );
+            }
+        }, listener::onFailure));
+    }
+
+    /**
+     * Stops the model deployment after a failed validation and then reports the validation failure to the listener.
+     * If stopping fails as well, the stop failure is reported with the validation failure attached as suppressed.
+     */
+    private void stopModelDeployment(InferenceService service, Model model, ActionListener<Model> listener, Exception e) {
+        service.stop(
+            model,
+            ActionListener.wrap(
+                (v) -> listener.onFailure(e),
+                (ex) -> {
+                    ElasticsearchStatusException stopFailure = new ElasticsearchStatusException(
+                        "Model validation failed and model deployment could not be stopped",
+                        RestStatus.INTERNAL_SERVER_ERROR,
+                        ex
+                    );
+                    // Keep the validation failure that triggered the stop so it is not lost from the response.
+                    stopFailure.addSuppressed(e);
+                    listener.onFailure(stopFailure);
+                }
+            )
+        );
     }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
index ffdc8a99d6e65..03ac5b95fddc5 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
@@ -35,7 +35,7 @@ public void validate(InferenceService service, Model model, TimeValue timeout, A
                 TEST_INPUT,
                 false,
                 Map.of(),
-                InputType.INGEST,
+                InputType.INTERNAL_INGEST,
                 timeout,
                 ActionListener.wrap(r -> {
                     if (r != null) {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
index e9b529740d804..3a817da2a6a4d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java
@@ -20,6 +20,7 @@
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
diff
--git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
new file mode 100644
index 0000000000000..d583d2d7f5661
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java
@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.validation;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.Before;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+import static org.mockito.MockitoAnnotations.openMocks;
+
+/**
+ * Tests that {@link ElasticsearchInternalServiceModelValidator} starts the model deployment before delegating to the
+ * wrapped {@link ModelValidator}, and stops the deployment again whenever that validation fails or throws.
+ */
+public class ElasticsearchInternalServiceModelValidatorTests extends ESTestCase {
+
+    private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE;
+    private static final String MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE =
+        "Model validation failed and model deployment could not be stopped";
+
+    @Mock
+    private ModelValidator mockModelValidator;
+    @Mock
+    private InferenceService mockInferenceService;
+    @Mock
+    private Model mockModel;
+    @Mock
+    private ActionListener<Model> mockActionListener;
+
+    private ElasticsearchInternalServiceModelValidator underTest;
+
+    @Before
+    public void setup() {
+        openMocks(this);
+
+        underTest = new ElasticsearchInternalServiceModelValidator(mockModelValidator);
+
+        // delegateResponse must run for real so that results from the wrapped validator reach the mocked listener.
+        when(mockActionListener.delegateResponse(any())).thenCallRealMethod();
+    }
+
+    public void testValidate_ModelDeploymentThrowsException() {
+        doThrow(ElasticsearchStatusException.class).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+
+        assertThrows(
+            ElasticsearchStatusException.class,
+            () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); }
+        );
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    public void testValidate_ModelDeploymentReturnsFalse() {
+        mockModelDeployment(false);
+
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsStopped() {
+        mockModelDeployment(true);
+        doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator)
+            .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        mockModelStop(true);
+
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verify(mockInferenceService).stop(eq(mockModel), any());
+        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateResponse(any());
+        verifyMockActionListenerAfterStopModelDeployment(true);
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsNotStopped() {
+        mockModelDeployment(true);
+        doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator)
+            .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        mockModelStop(false);
+
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verify(mockInferenceService).stop(eq(mockModel), any());
+        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateResponse(any());
+        verifyMockActionListenerAfterStopModelDeployment(false);
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    public void testValidate_ModelValidationFailsAndModelDeploymentIsStopped() {
+        mockModelDeployment(true);
+        doAnswer(ans -> {
+            ActionListener<Model> responseListener = ans.getArgument(3);
+            responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
+            return null;
+        }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        mockModelStop(true);
+
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verify(mockInferenceService).stop(eq(mockModel), any());
+        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateResponse(any());
+        verifyMockActionListenerAfterStopModelDeployment(true);
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    public void testValidate_ModelValidationFailsAndModelDeploymentIsNotStopped() {
+        mockModelDeployment(true);
+        doAnswer(ans -> {
+            ActionListener<Model> responseListener = ans.getArgument(3);
+            responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
+            return null;
+        }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        mockModelStop(false);
+
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verify(mockInferenceService).stop(eq(mockModel), any());
+        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateResponse(any());
+        verifyMockActionListenerAfterStopModelDeployment(false);
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    public void testValidate_ModelValidationSucceeds() {
+        mockModelDeployment(true);
+        // The wrapped validator reports success; it must reach the caller and the deployment must not be stopped.
+        doAnswer(ans -> {
+            ActionListener<Model> responseListener = ans.getArgument(3);
+            responseListener.onResponse(mockModel);
+            return null;
+        }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+
+        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateResponse(any());
+        verify(mockActionListener).onResponse(mockModel);
+        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    }
+
+    /** Stubs {@code start} so that it answers its listener with the given started flag. */
+    private void mockModelDeployment(boolean modelDeploymentStarted) {
+        doAnswer(ans -> {
+            ActionListener<Boolean> responseListener = ans.getArgument(2);
+            responseListener.onResponse(modelDeploymentStarted);
+            return null;
+        }).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+    }
+
+    /** Stubs {@code stop} so that it either answers its listener normally or fails it with a 500 status exception. */
+    private void mockModelStop(boolean modelDeploymentStopped) {
+        if (modelDeploymentStopped) {
+            doAnswer(ans -> {
+                ActionListener<Boolean> responseListener = ans.getArgument(1);
+                responseListener.onResponse(null);
+                return null;
+            }).when(mockInferenceService).stop(eq(mockModel), any());
+        } else {
+            doAnswer(ans -> {
+                ActionListener<Boolean> responseListener = ans.getArgument(1);
+                responseListener.onFailure(new ElasticsearchStatusException("Model stop failed", RestStatus.INTERNAL_SERVER_ERROR));
+                return null;
+            }).when(mockInferenceService).stop(eq(mockModel), any());
+        }
+    }
+
+    /** Checks the failure handed to the caller's listener; its message depends on whether the stop call succeeded. */
+    private void verifyMockActionListenerAfterStopModelDeployment(boolean modelDeploymentStopped) {
+        verify(mockInferenceService).stop(eq(mockModel), any());
+        ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
+        verify(mockActionListener).onFailure(exceptionCaptor.capture());
+        assertTrue(exceptionCaptor.getValue() instanceof ElasticsearchStatusException);
+        assertEquals(RestStatus.INTERNAL_SERVER_ERROR, ((ElasticsearchStatusException) exceptionCaptor.getValue()).status());
+
+        if (modelDeploymentStopped) {
+            assertFalse(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE));
+        } else {
+            assertTrue(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java
index ec2b3c864f101..6faa4bd07b6aa 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java
+++
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -69,7 +69,7 @@ public void testValidate_ServiceThrowsException() { eq(TEST_INPUT), eq(false), eq(Map.of()), - eq(InputType.INGEST), + eq(InputType.INTERNAL_INGEST), eq(TIMEOUT), any() ); @@ -104,7 +104,18 @@ private void mockSuccessfulCallToService(String query, InferenceServiceResults r responseListener.onResponse(result); return null; }).when(mockInferenceService) - .infer(eq(mockModel), eq(query), eq(TEST_INPUT), eq(false), eq(Map.of()), eq(InputType.INGEST), eq(TIMEOUT), any()); + .infer( + eq(mockModel), + eq(query), + eq(null), + eq(null), + eq(TEST_INPUT), + eq(false), + eq(Map.of()), + eq(InputType.INTERNAL_INGEST), + eq(TIMEOUT), + any() + ); underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } @@ -119,7 +130,7 @@ private void verifyCallToService(boolean withQuery) { eq(TEST_INPUT), eq(false), eq(Map.of()), - eq(InputType.INGEST), + eq(InputType.INTERNAL_INGEST), eq(TIMEOUT), any() );