diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 5cd9c2609b94c..333bdd76d1838 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -112,21 +112,26 @@ private Iterable providersFor(TaskType taskType) throws IOException { } public void testGetServicesWithRerankTaskType() throws IOException { + List services = getServices(TaskType.RERANK); + assertThat(services.size(), equalTo(11)); + + var providers = providers(services); + assertThat( providersFor(TaskType.RERANK), containsInAnyOrder( List.of( "alibabacloud-ai-search", + "amazon_sagemaker", "cohere", "custom", "elastic", "elasticsearch", "googlevertexai", + "hugging_face", "jinaai", "test_reranking_service", - "voyageai", - "hugging_face", - "amazon_sagemaker" + "voyageai" ).toArray() ) ); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 7f0212167f8ac..f2986e3c5aadf 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -46,7 +46,7 @@ public void enqueueAuthorizeAllModelsResponse() { "model_name": "multilingual-embed-v1", "task_types": ["embed/text/dense"] }, - { + { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 4c200c6f20247..cfd1d324d7d54 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -283,6 +283,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA { "model_name": "rerank-v1", "task_types": ["rerank/text/text-similarity"] + }, { "model_name": "multilingual-embed-v1", @@ -299,28 +300,29 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-v2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".multilingual-embed-v1-elastic", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + ); + containsInAnyOrder( + new InferenceService.DefaultConfigId( + ".elser-v2-elastic", + MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), + service + ), + new InferenceService.DefaultConfigId( + ".multilingual-embed-v1-elastic", + MinimalServiceSettings.textEmbedding( + ElasticInferenceService.NAME, + ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), + DenseVectorFieldMapper.ElementType.FLOAT ), - new InferenceService.DefaultConfigId( - ".rerank-v1-elastic", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) + service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ); assertThat( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 443fe6f0ec038..f2c47b554aed9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -210,7 +210,6 @@ private static Map initDefaultEndpoints( DenseVectorFieldMapper.ElementType.FLOAT ) ), - DEFAULT_RERANK_MODEL_ID_V1, new DefaultModelConfig( new ElasticInferenceServiceRerankModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java index 7e592406a718a..38e71d74b1716 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java @@ -87,7 +87,7 @@ public URI uri() { private URI createUri() throws ElasticsearchStatusException { try { // TODO, consider transforming the base URL into a URI for better error handling. - return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank"); + return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank/text/text-similarity"); } catch (URISyntaxException e) { throw new ElasticsearchStatusException( "Failed to create URI for service [" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 90a38f33d3e88..7372bb8d53953 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1315,6 +1315,7 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); var models = listener.actionGet(TIMEOUT); + assertThat(models.size(), is(4)); assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic"));