diff --git a/docs/changelog/123044.yaml b/docs/changelog/123044.yaml new file mode 100644 index 0000000000000..2cb758c23edec --- /dev/null +++ b/docs/changelog/123044.yaml @@ -0,0 +1,5 @@ +pr: 123044 +summary: Adding validation to `ElasticsearchInternalService` +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index f36642ab8d627..d85acb021506a 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -162,24 +162,13 @@ void chunkedInfer( /** * Stop the model deployment. * The default action does nothing except acknowledge the request (true). - * @param unparsedModel The unparsed model configuration + * @param model The model configuration * @param listener The listener */ - default void stop(UnparsedModel unparsedModel, ActionListener listener) { + default void stop(Model model, ActionListener listener) { listener.onResponse(true); } - /** - * Optionally test the new model configuration in the inference service. - * This function should be called when the model is first created, the - * default action is to do nothing. - * @param model The new model - * @param listener The listener - */ - default void checkModelConfig(Model model, ActionListener listener) { - listener.onResponse(model); - }; - /** * Update a text embedding model's dimensions based on a provided embedding * size and set the default similarity if required. The default behaviour is to just return the model. diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 6f9a550481049..a5c34dc2f14d5 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -64,7 +64,7 @@ public void testGetServicesWithoutTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithTextEmbeddingTaskType() throws IOException { List services = getServices(TaskType.TEXT_EMBEDDING); - assertThat(services.size(), equalTo(15)); + assertThat(services.size(), equalTo(16)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -86,6 +86,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { "jinaai", "mistral", "openai", + "test_service", "text_embedding_test_service", "voyageai", "watsonxai" @@ -157,7 +158,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { List services = getServices(TaskType.SPARSE_EMBEDDING); - assertThat(services.size(), equalTo(5)); + assertThat(services.size(), equalTo(6)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -166,7 +167,14 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { } assertArrayEquals( - List.of("alibabacloud-ai-search", "elastic", "elasticsearch", "hugging_face", 
"test_service").toArray(), + List.of( + "alibabacloud-ai-search", + "elastic", + "elasticsearch", + "hugging_face", + "streaming_completion_test_service", + "test_service" + ).toArray(), providers ); } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 03c5c6201ce33..0a463dbabd513 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -35,6 +35,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import java.io.IOException; @@ -63,7 +64,7 @@ public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSett public static class TestInferenceService extends AbstractTestInferenceService { public static final String NAME = "test_service"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING); public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} @@ -114,7 +115,8 @@ public void infer( ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeResults(input)); + case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeSparseEmbeddingResults(input)); + case TEXT_EMBEDDING -> listener.onResponse(makeTextEmbeddingResults(input)); default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -155,7 +157,7 @@ public void chunkedInfer( } } - private SparseEmbeddingResults makeResults(List input) { + private SparseEmbeddingResults makeSparseEmbeddingResults(List input) { var embeddings = new ArrayList(); for (int i = 0; i < input.size(); i++) { var tokens = new ArrayList(); @@ -167,6 +169,18 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } + private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { + var embeddings = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + var values = new float[5]; + for (int j = 0; j < 5; j++) { + values[j] = random.nextFloat(); + } + embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); + } + return new TextEmbeddingFloatResults(embeddings); + } + private List makeChunkedResults(List inputs) { List results = new ArrayList<>(); for (ChunkInferenceInput chunkInferenceInput : inputs) { diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 2320429f20704..e34018c5b8df1 
100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -36,8 +36,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; +import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; import java.util.Iterator; @@ -59,7 +61,11 @@ public static class TestInferenceService extends AbstractTestInferenceService { private static final String NAME = "streaming_completion_test_service"; private static final Set supportedStreamingTasks = Set.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION, + TaskType.SPARSE_EMBEDDING + ); public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {} @@ -115,7 +121,21 @@ public void infer( ActionListener listener ) { switch (model.getConfigurations().getTaskType()) { - case COMPLETION -> listener.onResponse(makeResults(input)); + case COMPLETION -> listener.onResponse(makeChatCompletionResults(input)); + case SPARSE_EMBEDDING -> { + if (stream) { + listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } else { + // Return text embedding results when creating a sparse_embedding inference endpoint to allow creation validation to + // pass. This is required to test that streaming fails for a sparse_embedding endpoint. 
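The comment above is the key to this branch: endpoint creation now performs one non-streaming validation inference call, so the mock can let creation succeed and still reject streaming afterwards. A rough sketch of the REST-level check this enables; the endpoint id, the JSON bodies, and the _stream URL are illustrative assumptions, and the module's real test helpers may differ:

public void testStreamingFailsForSparseEmbeddingEndpoint() throws IOException {
    var put = new Request("PUT", "_inference/sparse_embedding/sparse-stream-check");
    put.setJsonEntity("{\"service\": \"streaming_completion_test_service\", \"service_settings\": {\"model\": \"my_model\", \"api_key\": \"secret\"}}");
    client().performRequest(put); // creation succeeds via the non-streaming branch just below

    var stream = new Request("POST", "_inference/sparse_embedding/sparse-stream-check/_stream");
    stream.setJsonEntity("{\"input\": [\"some text\"]}");
    var e = expectThrows(ResponseException.class, () -> client().performRequest(stream));
    assertEquals(400, e.getResponse().getStatusLine().getStatusCode()); // rejected by the streaming branch shown above (BAD_REQUEST)
}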
+ listener.onResponse(makeTextEmbeddingResults(input)); + } + } default -> listener.onFailure( new ElasticsearchStatusException( TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), @@ -143,7 +163,7 @@ public void unifiedCompletionInfer( } } - private StreamingChatCompletionResults makeResults(List input) { + private StreamingChatCompletionResults makeChatCompletionResults(List input) { var responseIter = input.stream().map(s -> s.toUpperCase(Locale.ROOT)).iterator(); return new StreamingChatCompletionResults(subscriber -> { subscriber.onSubscribe(new Flow.Subscription() { @@ -162,6 +182,18 @@ public void cancel() {} }); } + private TextEmbeddingFloatResults makeTextEmbeddingResults(List input) { + var embeddings = new ArrayList(); + for (int i = 0; i < input.size(); i++) { + var values = new float[5]; + for (int j = 0; j < 5; j++) { + values[j] = random.nextFloat(); + } + embeddings.add(new TextEmbeddingFloatResults.Embedding(values)); + } + return new TextEmbeddingFloatResults(embeddings); + } + private InferenceServiceResults.Result completionChunk(String delta) { return new InferenceServiceResults.Result() { @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 6c9b534bdedb1..46f4babe5b7cf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -125,7 +125,9 @@ private void doExecuteForked( var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { - service.get().stop(unparsedModel, listener); + var model = service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + service.get().stop(model, listener); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index aa28d70b51009..4357fa619954c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -45,6 +45,7 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.io.IOException; import java.util.List; @@ -190,19 +191,23 @@ private void parseAndStoreModel( ActionListener storeModelListener = listener.delegateFailureAndWrap( (delegate, verifiedModel) -> modelRegistry.storeModel( verifiedModel, - ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> { - if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) { - delegate.onFailure( - new 
ElasticsearchStatusException( - "One or more nodes in your cluster does not support chunking_settings. " - + "Please update all nodes in your cluster to the latest version to use chunking_settings.", - RestStatus.BAD_REQUEST - ) - ); - } else { - delegate.onFailure(e); + ActionListener.wrap( + r -> listener.onResponse(new PutInferenceModelAction.Response(verifiedModel.getConfigurations())), + e -> { + if (e.getCause() instanceof StrictDynamicMappingException + && e.getCause().getMessage().contains("chunking_settings")) { + delegate.onFailure( + new ElasticsearchStatusException( + "One or more nodes in your cluster does not support chunking_settings. " + + "Please update all nodes in your cluster to the latest version to use chunking_settings.", + RestStatus.BAD_REQUEST + ) + ); + } else { + delegate.onFailure(e); + } } - }), + ), timeout ) ); @@ -211,26 +216,14 @@ private void parseAndStoreModel( if (skipValidationAndStart) { storeModelListener.onResponse(model); } else { - service.checkModelConfig(model, storeModelListener); + ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService) + .validate(service, model, timeout, storeModelListener); } }); service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener); } - private void startInferenceEndpoint( - InferenceService service, - TimeValue timeout, - Model model, - ActionListener listener - ) { - if (skipValidationAndStart) { - listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations())); - } else { - service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations()))); - } - } - private Map requestToMap(PutInferenceModelAction.Request request) throws IOException { try ( XContentParser parser = XContentHelper.createParser( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index a0c77599b6ce6..bdcadb2277c2b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -8,22 +8,17 @@ package org.elasticsearch.xpack.inference.services; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; @@ -723,53 +718,6 @@ public static ElasticsearchStatusException 
createInvalidModelException(Model mod ); } - /** - * Evaluate the model and return the text embedding size - * @param model Should be a text embedding model - * @param service The inference service - * @param listener Size listener - */ - public static void getEmbeddingSize(Model model, InferenceService service, ActionListener listener) { - assert model.getTaskType() == TaskType.TEXT_EMBEDDING; - - service.infer( - model, - null, - null, - null, - List.of(TEST_EMBEDDING_INPUT), - false, - Map.of(), - InputType.INTERNAL_INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener.delegateFailureAndWrap((delegate, r) -> { - if (r instanceof TextEmbeddingResults embeddingResults) { - try { - delegate.onResponse(embeddingResults.getFirstEmbeddingSize()); - } catch (Exception e) { - delegate.onFailure( - new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e) - ); - } - } else { - delegate.onFailure( - new ElasticsearchStatusException( - "Could not determine embedding size. " - + "Expected a result of type [" - + TextEmbeddingFloatResults.NAME - + "] got [" - + r.getWriteableName() - + "]", - RestStatus.BAD_REQUEST - ) - ); - } - }) - ); - } - - private static final String TEST_EMBEDDING_INPUT = "how big"; - public static SecureString apiKey(@Nullable ApiKeySecrets secrets) { // To avoid a possible null pointer throughout the code we'll create a noop api key of an empty array return secrets == null ? new SecureString(new char[0]) : secrets.apiKey(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index ac0d0df06b48d..6a659a9be2bbd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -348,19 +347,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. 
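The getEmbeddingSize helper removed above does not disappear; the same probe now runs inside TextEmbeddingModelValidator (updated later in this diff), which performs the validation inference call and then applies updateModelWithEmbeddingDetails. Its postValidate body is not part of this change, so the following is only a sketch of the presumed shape, reusing the type check and getFirstEmbeddingSize() call from the deleted code:

private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
    if (results instanceof TextEmbeddingResults embeddingResults) {
        // Same derivation the removed ServiceUtils.getEmbeddingSize performed.
        return service.updateModelWithEmbeddingDetails(model, embeddingResults.getFirstEmbeddingSize());
    }
    throw new ElasticsearchStatusException(
        "Could not determine embedding size; expected text embedding results but got [" + results.getWriteableName() + "]",
        RestStatus.BAD_REQUEST
    );
}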
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AlibabaCloudSearchEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 38d8d61873ce5..93e0033d88c65 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -47,7 +47,6 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.io.IOException; import java.util.EnumSet; @@ -334,19 +333,6 @@ public Set supportedStreamingTasks() { return COMPLETION_ONLY; } - /** - * For text embedding models get the embedding size and - * update the service settings. - * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index 52b1cb36b0c92..bec8908ab73f9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -177,12 +176,6 @@ public AnthropicModel parsePersistedConfig(String inferenceEntityId, TaskType ta ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, 
listener); - } - @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index a70f44b91f9f8..04883f23b947f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -332,12 +331,6 @@ private AzureAiStudioModel createModelFromPersistent( ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 03778e4471042..e9ff97c1ba725 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -43,7 +43,6 @@ import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -298,19 +297,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. 
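Each of these per-service checkModelConfig overrides collapses into the single entry point added in TransportPutInferenceModelAction earlier in this diff. For reference, the call site looks like this, and the per-task dispatch inside ModelValidatorBuilder (whose switch is not shown here) presumably selects among the validators updated below; treat that mapping as an assumption rather than a quote of the real method:

// Call site (paraphrased from the TransportPutInferenceModelAction hunk above):
ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
    .validate(service, model, timeout, storeModelListener);

// Assumed shape of the dispatch; only the null check appears in this diff:
private static ModelValidator buildModelValidatorForTaskType(TaskType taskType) {
    if (taskType == null) {
        throw new IllegalArgumentException("Task type can't be null");
    }
    return switch (taskType) {
        case TEXT_EMBEDDING -> new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator());
        case CHAT_COMPLETION -> new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator());
        default -> new SimpleModelValidator(new SimpleServiceIntegrationValidator());
    };
}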
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 66dc7a1de9a75..bf6a0bd03122b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -47,7 +47,6 @@ import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -311,19 +310,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. - * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof CohereEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java index 8eaffb7b96fdf..56719199e094f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -32,7 +32,6 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -176,12 +175,6 @@ public Set supportedStreamingTasks() { return SUPPORTED_TASK_TYPES_FOR_STREAMING; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - private static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 75a9e44d25b62..dcaafb702eba3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -54,7 +54,6 @@ import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.EnumSet; @@ -426,12 +425,6 @@ private ElasticInferenceServiceModel createModelFromPersistent( ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index c9b13de4de83d..84259d1c0be66 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -22,7 +22,6 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -118,9 +117,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL } @Override - public void stop(UnparsedModel unparsedModel, ActionListener listener) { - - var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + public void stop(Model model, ActionListener listener) { if (model instanceof ElasticsearchInternalModel esModel) { var serviceSettings = esModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ed331e9df658e..5e409e6ba5eb4 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -561,31 +561,6 @@ private void migrateModelVersionToModelId(Map serviceSettingsMap } } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) { - // At this point the inference endpoint configuration has not been persisted yet, if we attempt to do inference using the - // inference id we'll get an error because the trained model code needs to use the persisted inference endpoint to retrieve the - // model id. To get around this we'll have the getEmbeddingSize() method use the model id instead of inference id. So we need - // to create a temporary model that overrides the inference id with the model id. - var temporaryModelWithModelId = new CustomElandEmbeddingModel( - elandModel.getServiceSettings().modelId(), - elandModel.getTaskType(), - elandModel.getConfigurations().getService(), - elandModel.getServiceSettings(), - elandModel.getConfigurations().getChunkingSettings() - ); - - ServiceUtils.getEmbeddingSize( - temporaryModelWithModelId, - this, - listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(elandModel, size))) - ); - } else { - listener.onResponse(model); - } - } - private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) { CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings( model.getServiceSettings().getNumAllocations(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 99b5acbfc36c7..9841ea64370c3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -46,7 +46,6 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -255,12 +254,6 @@ public Set supportedStreamingTasks() { return COMPLETION_ONLY; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 8526e8abbad4d..e966ebc8d9e9b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -42,7 +42,6 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -188,12 +187,6 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 612748f6ede12..f2a53520e18e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -83,12 +82,6 @@ protected HuggingFaceModel createModel( }; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof HuggingFaceEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index c01d4d142fe16..7dfb0002bb062 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -43,7 +43,6 @@ import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import 
org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -237,12 +236,6 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index afd1d5db213bf..c2e88cb6cdc7c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -45,7 +45,6 @@ import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -292,18 +291,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. 
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof JinaAIEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 5c6488bfbbda2..558b7e255f2b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -43,7 +43,6 @@ import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -284,12 +283,6 @@ private MistralEmbeddingsModel createModelFromPersistent( ); } - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof MistralEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 094b6b27e158b..5ff456000f8b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -50,7 +50,6 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.EnumSet; import java.util.HashMap; @@ -351,19 +350,6 @@ protected void doChunkedInfer( } } - /** - * For text embedding models get the embedding size and - * update the service settings. 
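A related API shift runs alongside these deletions: InferenceService.stop now receives a parsed Model instead of an UnparsedModel, so the transport layer (see TransportDeleteInferenceEndpointAction above) resolves the persisted config once and deployment-aware services only pattern-match on their own model class. A minimal sketch of the new override shape, assuming a hypothetical stopDeployment helper in place of the trained-model stop call that BaseElasticsearchInternalService actually issues:

@Override
public void stop(Model model, ActionListener<Boolean> listener) {
    if (model instanceof ElasticsearchInternalModel esModel) {
        // Hypothetical helper: the real code builds a stop-deployment request from
        // esModel.getServiceSettings() and forwards the acknowledgement to the listener.
        stopDeployment(esModel, listener);
    } else {
        listener.onFailure(ServiceUtils.createInvalidModelException(model));
    }
}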
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - // TODO: Remove this function once all services have been updated to use the new model validators - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof OpenAiEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java index b7a9fa7e6f3ab..624f223c9f3e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; @@ -20,8 +21,8 @@ public ChatCompletionModelValidator(ServiceIntegrationValidator serviceIntegrati } @Override - public void validate(InferenceService service, Model model, ActionListener listener) { - serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + serviceIntegrationValidator.validate(service, model, timeout, listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(postValidate(service, model)); })); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java new file mode 100644 index 0000000000000..aa3bb0ef9bcb0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; + +public class ElasticsearchInternalServiceModelValidator implements ModelValidator { + + ModelValidator modelValidator; + + public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) { + this.modelValidator = modelValidator; + } + + @Override + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + service.start(model, timeout, ActionListener.wrap((modelDeploymentStarted) -> { + if (modelDeploymentStarted) { + try { + modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> { + stopModelDeployment(service, model, l, exception); + })); + } catch (Exception e) { + stopModelDeployment(service, model, listener, e); + } + } else { + listener.onFailure( + new ElasticsearchStatusException("Could not deploy model for inference endpoint", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + }, listener::onFailure)); + } + + private void stopModelDeployment(InferenceService service, Model model, ActionListener listener, Exception e) { + service.stop( + model, + ActionListener.wrap( + (v) -> listener.onFailure(e), + (ex) -> listener.onFailure( + new ElasticsearchStatusException( + "Model validation failed and model deployment could not be stopped", + RestStatus.INTERNAL_SERVER_ERROR, + ex + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java index c435939a17568..2c7701942ecb0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidator.java @@ -8,9 +8,10 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; public interface ModelValidator { - void validate(InferenceService service, Model model, ActionListener listener); + void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index 1c4306c4edd46..c0088858588ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -11,7 +11,16 @@ import org.elasticsearch.inference.TaskType; public class ModelValidatorBuilder { - public static ModelValidator buildModelValidator(TaskType taskType) { + public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) { + var modelValidator = 
buildModelValidatorForTaskType(taskType); + if (isElasticsearchInternalService) { + return new ElasticsearchInternalServiceModelValidator(modelValidator); + } else { + return modelValidator; + } + } + + private static ModelValidator buildModelValidatorForTaskType(TaskType taskType) { if (taskType == null) { throw new IllegalArgumentException("Task type can't be null"); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java index 09fb43f584cf0..49ade6c00fb22 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java @@ -8,10 +8,11 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; public interface ServiceIntegrationValidator { - void validate(InferenceService service, Model model, ActionListener listener); + void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java index eb65fe33fb84a..f9cf67172bc2a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java @@ -10,11 +10,11 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import java.util.List; @@ -28,32 +28,27 @@ public class SimpleChatCompletionServiceIntegrationValidator implements ServiceI private static final List TEST_INPUT = List.of("how big"); @Override - public void validate(InferenceService service, Model model, ActionListener listener) { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { var chatCompletionInput = new UnifiedChatInput(TEST_INPUT, USER_ROLE, false); - service.unifiedCompletionInfer( - model, - chatCompletionInput.getRequest(), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> { - if (r != null) { - listener.onResponse(r); - } else { - listener.onFailure( - new ElasticsearchStatusException( - "Could not complete inference endpoint creation as validation call to service returned null response.", - RestStatus.BAD_REQUEST - ) - ); - } - }, e -> { + 
service.unifiedCompletionInfer(model, chatCompletionInput.getRequest(), timeout, ActionListener.wrap(r -> { + if (r != null) { + listener.onResponse(r); + } else { listener.onFailure( new ElasticsearchStatusException( - "Could not complete inference endpoint creation as validation call to service threw an exception.", - RestStatus.BAD_REQUEST, - e + "Could not complete inference endpoint creation as validation call to service returned null response.", + RestStatus.BAD_REQUEST ) ); - }) - ); + } + }, e -> { + listener.onFailure( + new ElasticsearchStatusException( + "Could not complete inference endpoint creation as validation call to service threw an exception.", + RestStatus.BAD_REQUEST, + e + ) + ); + })); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java index f44cf61079369..3d592840f533b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.validation; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; @@ -20,11 +21,9 @@ public SimpleModelValidator(ServiceIntegrationValidator serviceIntegrationValida } @Override - public void validate(InferenceService service, Model model, ActionListener listener) { - serviceIntegrationValidator.validate( - service, - model, - listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(model); }) - ); + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + serviceIntegrationValidator.validate(service, model, timeout, listener.delegateFailureAndWrap((delegate, r) -> { + delegate.onResponse(model); + })); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 4c48e3018b956..03ac5b95fddc5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -10,13 +10,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import java.util.List; import java.util.Map; @@ -26,7 +26,7 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali private static final String QUERY = "test query"; @Override - public void validate(InferenceService service, Model model, ActionListener listener) { + 
public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { service.infer( model, model.getTaskType().equals(TaskType.RERANK) ? QUERY : null, @@ -36,7 +36,7 @@ public void validate(InferenceService service, Model model, ActionListener { if (r != null) { listener.onResponse(r); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java index c82d6c00c3616..bff04f5af2d75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.Strings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -26,8 +27,8 @@ public TextEmbeddingModelValidator(ServiceIntegrationValidator serviceIntegratio } @Override - public void validate(InferenceService service, Model model, ActionListener listener) { - serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> { + public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener listener) { + serviceIntegrationValidator.validate(service, model, timeout, listener.delegateFailureAndWrap((delegate, r) -> { delegate.onResponse(postValidate(service, model, r)); })); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 229266a5e51ed..0ffec057dc2b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -40,7 +40,6 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import org.elasticsearch.xpack.inference.services.voyageai.action.VoyageAIActionCreator; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsModel; import org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsServiceSettings; @@ -316,18 +315,6 @@ private static int getBatchSize(VoyageAIModel model) { return MODEL_BATCH_SIZES.getOrDefault(model.getServiceSettings().modelId(), DEFAULT_BATCH_SIZE); } - /** - * For text embedding models get the embedding size and - * update the service settings. 
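The validator signatures above now take the caller's TimeValue timeout, and the simple validators forward it to the test inference call instead of relying on InferenceAction.Request.DEFAULT_TIMEOUT, turning a null response or a thrown exception into a BAD_REQUEST failure. A minimal sketch of that shape, using hypothetical stand-ins (SimpleService, Duration, and the Consumer callbacks are illustrative placeholders, not the Elasticsearch classes):

    // Sketch only: hypothetical stand-ins for the Elasticsearch types, showing the shape of the
    // change -- the caller's timeout is forwarded to the inference call instead of a fixed default,
    // and a null response is reported as a validation failure.
    import java.time.Duration;
    import java.util.List;
    import java.util.function.Consumer;

    final class ValidationTimeoutSketch {

        // Minimal stand-in for an async inference call with a success/failure callback pair.
        interface SimpleService {
            void unifiedCompletionInfer(String modelId, List<String> input, Duration timeout,
                                        Consumer<Object> onResponse, Consumer<Exception> onFailure);
        }

        static void validate(SimpleService service, String modelId, Duration timeout,
                             Consumer<Object> onSuccess, Consumer<Exception> onFailure) {
            service.unifiedCompletionInfer(modelId, List.of("how big"), timeout, response -> {
                if (response != null) {
                    onSuccess.accept(response);
                } else {
                    // Mirrors the BAD_REQUEST failure raised when the service returns a null response.
                    onFailure.accept(new IllegalStateException(
                        "Could not complete inference endpoint creation as validation call to service returned null response."
                    ));
                }
            }, onFailure);
        }

        public static void main(String[] args) {
            SimpleService fake = (modelId, input, timeout, onResponse, onFailure) -> onResponse.accept("ok");
            validate(fake, "my-chat-endpoint", Duration.ofSeconds(30), System.out::println, Throwable::printStackTrace);
        }
    }

Passing the timeout through keeps endpoint-creation validation bounded by whatever the original request allowed rather than by a hard-coded default.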
- * - * @param model The new model - * @param listener The listener - */ - @Override - public void checkModelConfig(Model model, ActionListener listener) { - ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); - } - @Override public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof VoyageAIEmbeddingsModel embeddingsModel) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 190520fbc3b68..6b2731bb313b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -8,21 +8,11 @@ package org.elasticsearch.xpack.inference.services; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResultsTests; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.util.EnumSet; @@ -41,23 +31,14 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.getEmbeddingSize; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; public class ServiceUtilsTests extends ESTestCase { - - private static final TimeValue TIMEOUT = TimeValue.timeValueSeconds(30); - public void testRemoveAsTypeWithTheCorrectType() { Map map = new HashMap<>(Map.of("a", 5, "b", "a string", "c", Boolean.TRUE, "d", 1.0)); @@ -903,96 +884,6 @@ public void testExtractRequiredEnum_HasValidationErrorOnMissingSetting() { assertThat(validationException.validationErrors().get(0), is("[testscope] does not contain the required setting [missing_key]")); } - public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty() { - var service = 
mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(9); - listener.onResponse(new TextEmbeddingFloatResults(List.of())); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("Could not determine embedding size")); - assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty")); - } - - public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmpty() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(9); - listener.onResponse(new TextEmbeddingByteResults(List.of())); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - - assertThat(thrownException.getMessage(), is("Could not determine embedding size")); - assertThat(thrownException.getCause().getMessage(), is("Embeddings list is empty")); - } - - public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - var textEmbedding = TextEmbeddingFloatResultsTests.createRandomResults(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(9); - listener.onResponse(textEmbedding); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var size = listener.actionGet(TIMEOUT); - - assertThat(size, is(textEmbedding.embeddings().get(0).values().length)); - } - - public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { - var service = mock(InferenceService.class); - - var model = mock(Model.class); - when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); - - var textEmbedding = TextEmbeddingByteResultsTests.createRandomResults(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(9); - listener.onResponse(textEmbedding); - - return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - - PlainActionFuture listener = new PlainActionFuture<>(); - getEmbeddingSize(model, service, listener); - - var size = listener.actionGet(TIMEOUT); - - assertThat(size, is(textEmbedding.embeddings().get(0).values().length)); - } - public void testValidateInputType_NoValidationErrorsWhenInternalType() { ValidationException validationException = new ValidationException(); diff --git 
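The deleted ServiceUtilsTests cases exercised getEmbeddingSize and its empty-result and size-mismatch failures; that reconciliation sits on the validator path, where TextEmbeddingModelValidator runs a post-validation step after the test inference call and services still expose updateModelWithEmbeddingDetails. A rough sketch of the dimension check, as a simplified hypothetical helper rather than the real validator (the error text mirrors the message asserted in the removed tests):

    // Sketch only: a hypothetical helper (not the actual TextEmbeddingModelValidator) showing the
    // dimension check the removed tests asserted -- a user-specified size must match the size of the
    // embedding returned by the validation call, otherwise endpoint creation fails.
    final class EmbeddingSizeCheckSketch {

        static int resolveDimensions(Integer configuredDimensions, boolean dimensionsSetByUser, int retrievedSize, String inferenceId) {
            if (dimensionsSetByUser && configuredDimensions != null && configuredDimensions != retrievedSize) {
                throw new IllegalArgumentException(
                    "The retrieved embeddings size ["
                        + retrievedSize
                        + "] does not match the size specified in the settings ["
                        + configuredDimensions
                        + "]. Please recreate the ["
                        + inferenceId
                        + "] configuration with the correct dimensions"
                );
            }
            // Otherwise the service settings are updated with the size that was actually returned.
            return retrievedSize;
        }

        public static void main(String[] args) {
            System.out.println(resolveDimensions(null, false, 2, "id")); // size discovered from the response: 2
            try {
                resolveDimensions(3, true, 2, "id");
            } catch (IllegalArgumentException e) {
                System.out.println(e.getMessage());
            }
        }
    }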
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index a3acfbcfee35d..3a817da2a6a4d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -41,7 +41,6 @@ import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceFields; @@ -52,7 +51,6 @@ import org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettingsTests; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.sparse.AlibabaCloudSearchSparseModel; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; -import org.hamcrest.MatcherAssert; import org.junit.After; import org.junit.Before; @@ -261,71 +259,6 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChun } } - public void testCheckModelConfig() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool)) { - @Override - public void doInfer( - Model model, - InferenceInputs inputs, - Map taskSettings, - TimeValue timeout, - ActionListener listener - ) { - TextEmbeddingFloatResults results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { -0.028680f, 0.022033f })) - ); - - listener.onResponse(results); - } - }) { - Map serviceSettingsMap = new HashMap<>(); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); - serviceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); - serviceSettingsMap.put(ServiceFields.DIMENSIONS, 1536); - - Map taskSettingsMap = new HashMap<>(); - - Map secretSettingsMap = new HashMap<>(); - secretSettingsMap.put("api_key", "secret"); - - var model = AlibabaCloudSearchEmbeddingsModelTests.createModel( - "service", - TaskType.TEXT_EMBEDDING, - serviceSettingsMap, - taskSettingsMap, - secretSettingsMap - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - Map expectedServiceSettingsMap = new HashMap<>(); - expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.SERVICE_ID, "service_id"); - expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.HOST, "host"); - expectedServiceSettingsMap.put(AlibabaCloudSearchServiceSettings.WORKSPACE_NAME, "default"); - expectedServiceSettingsMap.put(ServiceFields.SIMILARITY, "DOT_PRODUCT"); - 
expectedServiceSettingsMap.put(ServiceFields.DIMENSIONS, 2); - - Map expectedTaskSettingsMap = new HashMap<>(); - - Map expectedSecretSettingsMap = new HashMap<>(); - expectedSecretSettingsMap.put("api_key", "secret"); - - var expectedModel = AlibabaCloudSearchEmbeddingsModelTests.createModel( - "service", - TaskType.TEXT_EMBEDDING, - expectedServiceSettingsMap, - expectedTaskSettingsMap, - expectedSecretSettingsMap - ); - - MatcherAssert.assertThat(result, is(expectedModel)); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 8d4ce151605a5..688bd3d4afc56 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -53,7 +53,6 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -1095,232 +1094,6 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { } } - public void testCheckModelConfig_IncludesMaxTokens_ForEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - null, - false, - 100, - null, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 2, - false, - 100, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ) - ) - ); - var inputStrings = requestSender.getInputs(); - - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - - public void testCheckModelConfig_HasSimilarity_ForEmbeddingsModel() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new 
AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - null, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 2, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ) - ) - ); - var inputStrings = requestSender.getInputs(); - - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - - public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 3, - true, - null, - null, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. 
" - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - var inputStrings = requestSender.getInputs(); - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException { - var sender = mock(Sender.class); - var factory = mock(HttpRequestSender.Factory.class); - when(factory.createSender()).thenReturn(sender); - - var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( - ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), - mockClusterServiceEmpty() - ); - - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { - try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { - var results = new TextEmbeddingFloatResults( - List.of(new TextEmbeddingFloatResults.Embedding(new float[] { 0.123F, 0.678F })) - ); - requestSender.enqueue(results); - - var model = AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 100, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AmazonBedrockEmbeddingsModelTests.createModel( - "id", - "region", - "model", - AmazonBedrockProvider.AMAZONTITAN, - 2, - false, - null, - SimilarityMeasure.COSINE, - null, - "access", - "secret" - ) - ) - ); - var inputStrings = requestSender.getInputs(); - - MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); - } - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var sender = mock(Sender.class); var factory = mock(HttpRequestSender.Factory.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index c688803b43ff1..fb13f2644cebd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -56,7 +56,6 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -843,141 +842,6 @@ public void testParsePersistedConfig_WithoutSecretsCreatesChatCompletionModel() } } - public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); - - var model = AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - 
AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - false, - null, - null, - null, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - 2, - false, - null, - SimilarityMeasure.DOT_PRODUCT, - null, - null - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "input_type", "document"))); - } - } - - public void testCheckModelConfig_ForEmbeddingsModel_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); - - var model = AzureAiStudioEmbeddingsModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - 3, - true, - null, - null, - null, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "dimensions", 3, "input_type", "document")) - ); - } - } - - public void testCheckModelConfig_WorksForChatCompletionsModel() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testChatCompletionResultJson)); - - var model = AzureAiStudioChatCompletionModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - null, - null, - null, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureAiStudioChatCompletionModelTests.createModel( - "id", - getUrl(webServer), - AzureAiStudioProvider.OPENAI, - AzureAiStudioEndpointType.TOKEN, - "apikey", - null, - null, - null, - AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS, - null - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { diff --git 
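These testCheckModelConfig_* removals follow from dropping checkModelConfig; validation is instead assembled by ModelValidatorBuilder, which picks a validator for the task type and, for the Elasticsearch internal service, wraps it in ElasticsearchInternalServiceModelValidator (see the builder hunk at the top of this section). A compact sketch of that decorator shape, with hypothetical stand-in types in place of the real ModelValidator interfaces; what the real wrapper does around the base validation is only indicated, not reproduced:

    // Sketch only, with hypothetical stand-in types (TaskKind, Validator): the builder picks a
    // validator for the task type and, for the internal Elasticsearch service, wraps it in a
    // decorator standing in for ElasticsearchInternalServiceModelValidator.
    import java.util.function.UnaryOperator;

    final class ValidatorBuilderSketch {

        enum TaskKind { TEXT_EMBEDDING, SPARSE_EMBEDDING, CHAT_COMPLETION }

        interface Validator {
            String validate(String modelId);
        }

        static Validator build(TaskKind taskKind, boolean isElasticsearchInternalService) {
            if (taskKind == null) {
                throw new IllegalArgumentException("Task type can't be null");
            }
            Validator base = switch (taskKind) {
                case TEXT_EMBEDDING -> modelId -> modelId + ": embedding dimensions verified";
                case SPARSE_EMBEDDING, CHAT_COMPLETION -> modelId -> modelId + ": test inference call succeeded";
            };
            // The decorator runs extra, service-specific work around the wrapped validator.
            UnaryOperator<Validator> wrapForInternalService =
                inner -> modelId -> "[internal service] " + inner.validate(modelId);
            return isElasticsearchInternalService ? wrapForInternalService.apply(base) : base;
        }

        public static void main(String[] args) {
            System.out.println(build(TaskKind.TEXT_EMBEDDING, true).validate("my-e5-endpoint"));
        }
    }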
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index facdb90873d47..36cc6cb051b46 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -851,376 +850,6 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { } } - public void testCheckModelConfig_IncludesMaxTokens() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - null, - false, - 100, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - 100, - SimilarityMeasure.DOT_PRODUCT, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // service.checkModelConfig validates by performing inference with InputType.INTERNAL_INGEST - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "user", "user", "input_type", "internal_ingest")) - ); - } - } - - public void testCheckModelConfig_HasSimilarity() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - null, - false, - null, - SimilarityMeasure.COSINE, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture 
listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - null, - SimilarityMeasure.COSINE, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // service.checkModelConfig validates by performing inference with InputType.INTERNAL_INGEST - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "user", "user", "input_type", "internal_ingest")) - ); - } - } - - public void testCheckModelConfig_AddsDefaultSimilarityDotProduct() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - null, - false, - null, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - null, - SimilarityMeasure.DOT_PRODUCT, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // service.checkModelConfig validates by performing inference with InputType.INTERNAL_INGEST - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "user", "user", "input_type", "internal_ingest")) - ); - } - } - - public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException, URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 3, - true, - 100, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] 
does not match the size specified in the settings [3]. " - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // service.checkModelConfig validates by performing inference with InputType.INTERNAL_INGEST - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "user", "user", "dimensions", 3, "input_type", "internal_ingest")) - ); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException, - URISyntaxException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 100, - false, - 100, - null, - "user", - "apikey", - null, - "id" - ); - model.setUri(new URI(getUrl(webServer))); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is( - AzureOpenAiEmbeddingsModelTests.createModel( - "resource", - "deployment", - "apiversion", - 2, - false, - 100, - SimilarityMeasure.DOT_PRODUCT, - "user", - "apikey", - null, - "id" - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // service.checkModelConfig validates by performing inference with InputType.INTERNAL_INGEST - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "user", "user", "input_type", "internal_ingest")) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index 1d0d921956b73..469e9e55c695f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -888,198 +888,6 @@ public void testInfer_SendsRequest() throws IOException { } } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ - 0.123, - -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": 
"embeddings_by_type" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - null, - null - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - null, - null - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ - 0.123, - -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - null, - null - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - null, - null, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", - "texts": [ - "hello" - ], - "embeddings": { - "float": [ - [ - 0.123, - -0.123 - ] - ] - }, - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 1 - } - }, - "response_type": "embeddings_by_type" - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - null, - null, - SimilarityMeasure.COSINE - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - CohereEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - null, - null, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void 
testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 6a93b1cc19c87..610b50dc37a19 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -330,32 +330,6 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } - public void testCheckModelConfig_ReturnsNewModelReference() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = createService(senderFactory, getUrl(webServer))) { - String responseJson = """ - { - "data": [ - { - "hello": 2.1259406, - "greet": 1.7073475 - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat(returnedModel, is(ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id"))); - } - } - public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException { var sender = mock(Sender.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 7067577b30189..0163117ffcd6e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -1490,67 +1490,6 @@ public void onFailure(Exception e) { assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); } - public void testParseRequestConfigEland_SetsDimensionsToOne() { - var client = mock(Client.class); - doAnswer(invocationOnMock -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocationOnMock - .getArguments()[2]; - listener.onResponse( - new InferModelAction.Response(List.of(new MlTextEmbeddingResults("field", new double[] { 0.1 }, false)), "id", true) - ); - - var request = (InferModelAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getId(), is("custom-model")); - return Void.TYPE; - }).when(client).execute(eq(InferModelAction.INSTANCE), any(), any()); - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings( - 1, - 4, - "custom-model", - null, - null, - 1, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ); - var taskType = TaskType.TEXT_EMBEDDING; - var expectedModel = new 
CustomElandEmbeddingModel( - randomInferenceEntityId, - taskType, - ElasticsearchInternalService.NAME, - serviceSettings, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig( - new CustomElandEmbeddingModel( - randomInferenceEntityId, - taskType, - ElasticsearchInternalService.NAME, - new CustomElandInternalTextEmbeddingServiceSettings( - 1, - 4, - "custom-model", - null, - null, - null, - SimilarityMeasure.COSINE, - DenseVectorFieldMapper.ElementType.FLOAT - ), - null - ), - listener - ); - var model = listener.actionGet(TimeValue.THIRTY_SECONDS); - assertThat(model, is(expectedModel)); - } - public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { { assertFalse( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index f7e228cb3044c..4581c23563e0b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -1031,132 +1031,6 @@ public void testInfer_ResourceNotFound() throws IOException { } } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - var modelId = "model"; - var apiKey = "apiKey"; - - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "embeddings": [ - { - "values": [ - 0.0123, - -0.0123 - ] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, 1, similarityMeasure); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - // Updates dimensions to two as two embeddings were returned instead of one as specified before - assertThat( - result, - is(GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, 2, similarityMeasure)) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var oneDimension = 1; - var modelId = "model"; - var apiKey = "apiKey"; - - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "embeddings": [ - { - "values": [ - 0.0123 - ] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, oneDimension, null); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - GoogleAiStudioEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - apiKey, - oneDimension, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void 
testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var oneDimension = 1; - var modelId = "model"; - var apiKey = "apiKey"; - - try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "embeddings": [ - { - "values": [ - 0.0123 - ] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = GoogleAiStudioEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - apiKey, - oneDimension, - SimilarityMeasure.COSINE - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - GoogleAiStudioEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - apiKey, - oneDimension, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 1b4754f25e59a..b50b821d66359 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -666,90 +666,6 @@ public void testInfer_SendsElserRequest() throws IOException { } } - public void testCheckModelConfig_IncludesMaxTokens() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "embeddings": [ - [ - -0.0123 - ] - ] - { - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.DOT_PRODUCT)) - ); - } - } - - public void testCheckModelConfig_UsesUserSpecifiedSimilarity() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "embeddings": [ - [ - -0.0123 - ] - ] - { - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 2, SimilarityMeasure.COSINE); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = 
listener.actionGet(TIMEOUT); - assertThat( - result, - is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE)) - ); - } - } - - public void testCheckModelConfig_DefaultsSimilarityToCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "embeddings": [ - [ - -0.0123 - ] - ] - { - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 2, null); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1, 1, SimilarityMeasure.COSINE)) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 6966ff78575b5..27db8644f2952 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -863,169 +863,6 @@ public void testInfer_ResourceNotFound() throws IOException { } } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - - try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 1, - similarityMeasure - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - // Updates dimensions to two as two embeddings were returned instead of one as specified before - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 2, - similarityMeasure - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - 
], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxServiceWithoutAuth(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index 47937629fca3c..b4a39be58b245 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -808,184 +808,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotJinaAIModel() throws IOException verifyNoMoreInteractions(sender); } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "jina-clip-v2", - JinaAIEmbeddingType.FLOAT - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, 
listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "jina-clip-v2", - JinaAIEmbeddingType.FLOAT - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "jina-clip-v2", - null, - JinaAIEmbeddingType.FLOAT - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "jina-clip-v2", - SimilarityMeasure.DOT_PRODUCT, - JinaAIEmbeddingType.FLOAT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new JinaAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "jina-clip-v2", - "object": "list", - "usage": { - "total_tokens": 5, - "prompt_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "jina-clip-v2", - SimilarityMeasure.COSINE, - JinaAIEmbeddingType.FLOAT - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - JinaAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - JinaAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "jina-clip-v2", - SimilarityMeasure.COSINE, - JinaAIEmbeddingType.FLOAT - ) - ) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { testUpdateModelWithEmbeddingDetails_Successful(null); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 
cca1627e767c4..1b9bb447b2e60 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -49,7 +49,6 @@ import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -504,34 +503,6 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC } } - public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); - - var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); - model.setURI(getUrl(webServer)); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat( - result, - is(MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", 2, null, SimilarityMeasure.DOT_PRODUCT, null)) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "encoding_format", "float", "model", "mistral-embed")) - ); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 3229a750dee42..f727fee44cea1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -56,7 +56,6 @@ import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests; import org.hamcrest.CoreMatchers; -import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; @@ -66,7 +65,6 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -1351,393 +1349,6 @@ public void testSupportsStreaming() throws IOException { } } - public void testCheckModelConfig_IncludesMaxTokens() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String 
responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", 100); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var result = listener.actionGet(TIMEOUT); - assertThat(result, is(OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", 100, 2))); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 3, true); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); - assertThat( - exception.getMessage(), - is( - "The retrieved embeddings size [2] does not match the size specified in the settings [3]. 
" - + "Please recreate the [id] configuration with the correct dimensions" - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user", "dimensions", 3)) - ); - } - } - - public void testCheckModelConfig_ReturnsModelWithDimensionsSetTo2_AndDocProductSet_IfDimensionsSetByUser_ButSetToNull() - throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, null, true); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - true - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - // since dimensions were null they should not be sent in the request - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ReturnsModelWithSameDimensions_AndDocProductSet_IfDimensionsSetByUser_AndTheyMatchReturnedSize() - throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 2, true); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - true - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat( - requestMap, - Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user", "dimensions", 2)) - ); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, 
clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 100, false); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - false - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_SetsSimilarityToDocProduct_WhenNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel(getUrl(webServer), "org", "secret", "model", "user", null, 100, 100, false); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.DOT_PRODUCT, - 100, - 2, - false - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - - public void testCheckModelConfig_ReturnsNewModelReference_DoesNotOverrideSimilarity_WhenNotNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.COSINE, - 100, - 100, - false - ); - PlainActionFuture listener = new PlainActionFuture<>(); - 
service.checkModelConfig(model, listener); - - var returnedModel = listener.actionGet(TIMEOUT); - assertThat( - returnedModel, - is( - OpenAiEmbeddingsModelTests.createModel( - getUrl(webServer), - "org", - "secret", - "model", - "user", - SimilarityMeasure.COSINE, - 100, - 2, - false - ) - ) - ); - - assertThat(webServer.requests(), hasSize(1)); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "model", "model", "user", "user"))); - } - } - public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { try (var service = createOpenAiService()) { var model = createCompletionModel( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java index 89ab07d25e83d..bd52bdd52ab3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -26,6 +27,9 @@ import static org.mockito.MockitoAnnotations.openMocks; public class ChatCompletionModelValidatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + @Mock private ServiceIntegrationValidator mockServiceIntegrationValidator; @Mock @@ -48,14 +52,14 @@ public void setup() { public void testValidate_ServiceIntegrationValidatorThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator) - .validate(eq(mockInferenceService), eq(mockModel), any()); + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verifyNoMoreInteractions( mockServiceIntegrationValidator, @@ -70,14 +74,14 @@ public void testValidate_ChatCompletionDetailsUpdated() { when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod(); when(mockInferenceService.updateModelWithChatCompletionDetails(mockModel)).thenReturn(mockModel); doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); + ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(mockInferenceServiceResults); return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, 
mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verify(mockActionListener).onResponse(mockModel); verify(mockInferenceService).updateModelWithChatCompletionDetails(mockModel); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java new file mode 100644 index 0000000000000..d583d2d7f5661 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class ElasticsearchInternalServiceModelValidatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + private static final String MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE = + "Model validation failed and model deployment could not be stopped"; + + @Mock + private ModelValidator mockModelValidator; + @Mock + private InferenceService mockInferenceService; + @Mock + private Model mockModel; + @Mock + private ActionListener mockActionListener; + + private ElasticsearchInternalServiceModelValidator underTest; + + @Before + public void setup() { + openMocks(this); + + underTest = new ElasticsearchInternalServiceModelValidator(mockModelValidator); + + when(mockActionListener.delegateResponse(any())).thenCallRealMethod(); + } + + public void testValidate_ModelDeploymentThrowsException() { + doThrow(ElasticsearchStatusException.class).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + + assertThrows( + ElasticsearchStatusException.class, + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } + ); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + public void testValidate_ModelDeploymentReturnsFalse() { + 
mockModelDeployment(false); + + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class)); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsStopped() { + mockModelDeployment(true); + doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator) + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + mockModelStop(true); + + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verify(mockInferenceService).stop(eq(mockModel), any()); + verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + verify(mockActionListener).delegateResponse(any()); + verifyMockActionListenerAfterStopModelDeployment(true); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsNotStopped() { + mockModelDeployment(true); + doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator) + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + mockModelStop(false); + + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verify(mockInferenceService).stop(eq(mockModel), any()); + verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + verify(mockActionListener).delegateResponse(any()); + verifyMockActionListenerAfterStopModelDeployment(false); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + public void testValidate_ModelValidationFailsAndModelDeploymentIsStopped() { + mockModelDeployment(true); + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + mockModelStop(true); + + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verify(mockInferenceService).stop(eq(mockModel), any()); + verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + verify(mockActionListener).delegateResponse(any()); + verifyMockActionListenerAfterStopModelDeployment(true); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + public void testValidate_ModelValidationFailsAndModelDeploymentIsNotStopped() { + mockModelDeployment(true); + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(3); + responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), 
eq(TIMEOUT), any()); + mockModelStop(false); + + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verify(mockInferenceService).stop(eq(mockModel), any()); + verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + verify(mockActionListener).delegateResponse(any()); + verifyMockActionListenerAfterStopModelDeployment(false); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + public void testValidate_ModelValidationSucceeds() { + mockModelDeployment(true); + mockModelStop(true); + + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); + + verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); + verify(mockActionListener).delegateResponse(any()); + verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener); + } + + private void mockModelDeployment(boolean modelDeploymentStarted) { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(2); + responseListener.onResponse(modelDeploymentStarted); + return null; + }).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any()); + } + + private void mockModelStop(boolean modelDeploymentStopped) { + if (modelDeploymentStopped) { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(1); + responseListener.onResponse(null); + return null; + }).when(mockInferenceService).stop(eq(mockModel), any()); + } else { + doAnswer(ans -> { + ActionListener responseListener = ans.getArgument(1); + responseListener.onFailure(new ElasticsearchStatusException("Model stop failed", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(mockInferenceService).stop(eq(mockModel), any()); + } + } + + private void verifyMockActionListenerAfterStopModelDeployment(boolean modelDeploymentStopped) { + verify(mockInferenceService).stop(eq(mockModel), any()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(mockActionListener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof ElasticsearchStatusException); + assertEquals(RestStatus.INTERNAL_SERVER_ERROR, ((ElasticsearchStatusException) exceptionCaptor.getValue()).status()); + + if (modelDeploymentStopped) { + assertFalse(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE)); + } else { + assertTrue(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE)); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java index a854bbdec507a..19ea0bedaaea5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java @@ -16,12 +16,12 @@ public class ModelValidatorBuilderTests extends ESTestCase { public void testBuildModelValidator_NullTaskType() { - assertThrows(IllegalArgumentException.class, () -> { 
ModelValidatorBuilder.buildModelValidator(null); }); + assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, false); }); } public void testBuildModelValidator_ValidTaskType() { taskTypeToModelValidatorClassMap().forEach((taskType, modelValidatorClass) -> { - assertThat(ModelValidatorBuilder.buildModelValidator(taskType), isA(modelValidatorClass)); + assertThat(ModelValidatorBuilder.buildModelValidator(taskType, false), isA(modelValidatorClass)); }); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java index f02c4662d49e4..1103a26c45b2e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidatorTests.java @@ -9,13 +9,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -44,6 +44,7 @@ public class SimpleChatCompletionServiceIntegrationValidatorTests extends ESTest null, null ); + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @Mock private InferenceService mockInferenceService; @@ -67,9 +68,12 @@ public void setup() { public void testValidate_ServiceThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockInferenceService) - .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); - assertThrows(ElasticsearchStatusException.class, () -> underTest.validate(mockInferenceService, mockModel, mockActionListener)); + assertThrows( + ElasticsearchStatusException.class, + () -> underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener) + ); verifyCallToService(); } @@ -117,10 +121,9 @@ private void mockSuccessfulCallToService(InferenceServiceResults result) { ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(result); return null; - }).when(mockInferenceService) - .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + }).when(mockInferenceService).unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void mockNullResponseFromService() { @@ -132,19 +135,13 @@ private void mockFailureResponseFromService(Exception exception) { ActionListener responseListener = ans.getArgument(3); responseListener.onFailure(exception); return null; - 
}).when(mockInferenceService) - .unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), any()); + }).when(mockInferenceService).unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyCallToService() { - verify(mockInferenceService).unifiedCompletionInfer( - eq(mockModel), - eq(EXPECTED_REQUEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), - any() - ); + verify(mockInferenceService).unifiedCompletionInfer(eq(mockModel), eq(EXPECTED_REQUEST), eq(TIMEOUT), any()); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java index 949d20c7c7ce3..418584d25d085 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -27,6 +28,9 @@ import static org.mockito.MockitoAnnotations.openMocks; public class SimpleModelValidatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + @Mock private ServiceIntegrationValidator mockServiceIntegrationValidator; @Mock @@ -49,11 +53,11 @@ public void setup() { public void testValidate_ServiceIntegrationValidatorThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator) - .validate(eq(mockInferenceService), eq(mockModel), any()); + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); verifyInteractions(); } @@ -66,16 +70,16 @@ public void testValidate_ServiceReturnsInferenceServiceResults() { private void mockCallToServiceIntegrationValidator(InferenceServiceResults results) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); + ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(results); return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyInteractions() { - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), 
eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java index 9ee2201b4f02b..6faa4bd07b6aa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -9,13 +9,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.junit.Before; import org.mockito.Mock; @@ -35,6 +35,7 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase { private static final List TEST_INPUT = List.of("how big"); private static final String TEST_QUERY = "test query"; + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; @Mock private InferenceService mockInferenceService; @@ -69,13 +70,13 @@ public void testValidate_ServiceThrowsException() { eq(false), eq(Map.of()), eq(InputType.INTERNAL_INGEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), + eq(TIMEOUT), any() ); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); verifyCallToService(false); @@ -112,11 +113,11 @@ private void mockSuccessfulCallToService(String query, InferenceServiceResults r eq(false), eq(Map.of()), eq(InputType.INTERNAL_INGEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), + eq(TIMEOUT), any() ); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } private void verifyCallToService(boolean withQuery) { @@ -130,7 +131,7 @@ private void verifyCallToService(boolean withQuery) { eq(false), eq(Map.of()), eq(InputType.INTERNAL_INGEST), - eq(InferenceAction.Request.DEFAULT_TIMEOUT), + eq(TIMEOUT), any() ); verifyNoMoreInteractions(mockInferenceService, mockModel, mockActionListener, mockInferenceServiceResults); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java index 10ad38e7eee5c..55a02ebab082b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java @@ -9,6 +9,7 @@ import 
org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; @@ -37,6 +38,9 @@ import static org.mockito.MockitoAnnotations.openMocks; public class TextEmbeddingModelValidatorTests extends ESTestCase { + + private static final TimeValue TIMEOUT = TimeValue.ONE_MINUTE; + @Mock private ServiceIntegrationValidator mockServiceIntegrationValidator; @Mock @@ -64,14 +68,14 @@ public void setup() { public void testValidate_ServiceIntegrationValidatorThrowsException() { doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator) - .validate(eq(mockInferenceService), eq(mockModel), any()); + .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); assertThrows( ElasticsearchStatusException.class, - () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); } + () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); } ); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockModel, mockActionListener, mockServiceSettings); } @@ -143,14 +147,14 @@ private void mockSuccessfulValidation(Boolean dimensionsSetByUser) { private void mockCallToServiceIntegrationValidator(InferenceServiceResults results) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(2); + ActionListener responseListener = ans.getArgument(3); responseListener.onResponse(results); return null; - }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); - underTest.validate(mockInferenceService, mockModel, mockActionListener); + underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); - verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any()); + verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any()); verify(mockActionListener).delegateFailureAndWrap(any()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 7155858a1ac03..7b53dc959d0ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -795,175 +795,6 @@ public void testInfer_ThrowsValidationErrorForInvalidInputType() throws IOExcept verifyNoMoreInteractions(sender); } - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": 
"voyage-3-large", - "object": "list", - "usage": { - "total_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "voyage-3-large" - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "voyage-3-large" - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "voyage-3-large", - "object": "list", - "usage": { - "total_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "voyage-3-large", - (SimilarityMeasure) null - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "voyage-3-large", - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new VoyageAIService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "model": "voyage-3-large", - "object": "list", - "usage": { - "total_tokens": 5 - }, - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.123, - -0.123 - ] - } - ] - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 1, - "voyage-3-large", - SimilarityMeasure.COSINE - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - MatcherAssert.assertThat( - result, - // the dimension is set to 2 because there are 2 embeddings returned from the mock server - is( - VoyageAIEmbeddingsModelTests.createModel( - getUrl(webServer), - "secret", - VoyageAIEmbeddingsTaskSettings.EMPTY_SETTINGS, - 10, - 2, - "voyage-3-large", - SimilarityMeasure.COSINE - ) - ) - ); - } - } - public void 
testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException { testUpdateModelWithEmbeddingDetails_Successful(null); }