diff --git a/docs/changelog/130336.yaml b/docs/changelog/130336.yaml new file mode 100644 index 0000000000000..090374287743c --- /dev/null +++ b/docs/changelog/130336.yaml @@ -0,0 +1,5 @@ +pr: 130336 +summary: "[EIS] Rename the elser 2 default model and the default inference endpoint" +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 3a2a003636b13..5416ccdf3b01d 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -12,10 +12,14 @@ import org.elasticsearch.inference.TaskType; import java.io.IOException; +import java.util.List; +import java.util.Map; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getAllModels; import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModels; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertTrue; public class InferenceGetModelsWithElasticInferenceServiceIT extends BaseMockEISAuthServerTest { @@ -23,12 +27,19 @@ public void testGetDefaultEndpoints() throws IOException { var allModels = getAllModels(); var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); - assertThat(allModels, hasSize(4)); + assertThat(allModels, hasSize(5)); assertThat(chatCompletionModels, hasSize(1)); for (var model : chatCompletionModels) { assertEquals("chat_completion", model.get("task_type")); } + assertInferenceIdTaskType(allModels, ".elser-2-elastic", TaskType.SPARSE_EMBEDDING); + } + + private static void assertInferenceIdTaskType(List> models, String inferenceId, TaskType taskType) { + var model = models.stream().filter(m -> m.get("inference_id").equals(inferenceId)).findFirst(); + assertTrue("could not find inference id: " + inferenceId, model.isPresent()); + assertThat(model.get().get("task_type"), is(taskType.toString())); } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index 3ea011c1317cc..a5b3663bf3605 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -34,7 +34,7 @@ public static MockElasticInferenceServiceAuthorizationServer enabledWithRainbowS "task_types": ["chat"] }, { - "model_name": "elser-v2", + "model_name": "elser_model_2", "task_types": ["embed/text/sparse"] } ] diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index cc1556d414215..6477ec07fa4ec 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -42,6 +42,8 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { @@ -190,7 +192,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA "task_types": ["chat"] }, { - "model_name": "elser-v2", + "model_name": "elser_model_2", "task_types": ["embed/text/sparse"] } ] @@ -205,21 +207,17 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.ANY))); assertThat( service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(), - service - ) - ) + containsInAnyOrder( + new InferenceService.DefaultConfigId(".elser-2-elastic", MinimalServiceSettings.sparseEmbedding(), service), + new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) ) ); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); + assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); var getModelListener = new PlainActionFuture(); // persists the default endpoints @@ -235,7 +233,7 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA { "models": [ { - "model_name": "elser-v2", + "model_name": "elser_model_2", "task_types": ["embed/text/sparse"] } ] @@ -248,7 +246,12 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ensureAuthorizationCallFinished(service); assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - assertTrue(service.defaultConfigIds().isEmpty()); + assertThat( + service.defaultConfigIds(), + containsInAnyOrder( + new InferenceService.DefaultConfigId(".elser-2-elastic", MinimalServiceSettings.sparseEmbedding(), service) + ) + ); assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); var getModelListener = new PlainActionFuture(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index d2f3cbb401fe5..a713aae60c3a7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -80,7 +80,11 @@ public class ElasticInferenceService extends SenderService { private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION); private static final String SERVICE_NAME = "Elastic"; static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + + // elser-2 + static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; + static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); /** * The task types that the {@link InferenceAction.Request} can accept. @@ -133,6 +137,19 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents ), MinimalServiceSettings.chatCompletion() + ), + DEFAULT_ELSER_2_MODEL_ID, + new DefaultModelConfig( + new ElasticInferenceServiceSparseEmbeddingsModel( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + NAME, + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + MinimalServiceSettings.sparseEmbedding() ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java index c1764b93bfc82..e5af1da030ef1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java @@ -26,8 +26,4 @@ public static boolean isValidModel(String model) { return model != null && VALID_ELSER_MODEL_IDS.contains(model); } - public static boolean isValidEisModel(String model) { - return ELSER_V2_MODEL.equals(model); - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index d1070981a6e03..edef1d1615baa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -86,6 +86,10 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -960,6 +964,18 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo { "model_name": "rainbow-sprinkles", "task_types": ["chat"] + }, + { + "model_name": "elser_model_2", + "task_types": ["embed/text/sparse"] + }, + { + "model_name": "multilingual-embed-v1", + "task_types": ["embed/text/dense"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] } ] } @@ -976,15 +992,19 @@ public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCo service.defaultConfigIds(), is( List.of( + new InferenceService.DefaultConfigId(".elser-2-elastic", MinimalServiceSettings.sparseEmbedding(), service), new InferenceService.DefaultConfigId(".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(), service) ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + var models = listener.actionGet(TIMEOUT); + assertThat(models.size(), is(2)); + assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); + assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index 5435d5b9a6dad..540ad54f1c7c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -24,6 +24,8 @@ import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.junit.Before; @@ -165,6 +167,19 @@ private static Map initDefaultEndpoints() { ElasticInferenceServiceComponents.EMPTY_INSTANCE ), MinimalServiceSettings.chatCompletion() + ), + "elser-2", + new DefaultModelConfig( + new ElasticInferenceServiceSparseEmbeddingsModel( + defaultEndpointId("elser-2"), + TaskType.SPARSE_EMBEDDING, + "test", + new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + ElasticInferenceServiceComponents.EMPTY_INSTANCE + ), + MinimalServiceSettings.sparseEmbedding() ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java index fa0148ac69df5..d9ffddd62fb40 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java @@ -19,21 +19,7 @@ public void testIsValidModel() { assertTrue(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel(randomElserModel())); } - public void testIsValidEisModel() { - assertTrue( - org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel( - org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL - ) - ); - } - public void testIsInvalidModel() { assertFalse(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel("invalid")); } - - public void testIsInvalidEisModel() { - assertFalse( - org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86) - ); - } }