diff --git a/docs/changelog/113873.yaml b/docs/changelog/113873.yaml new file mode 100644 index 0000000000000..ac52aaf94d518 --- /dev/null +++ b/docs/changelog/113873.yaml @@ -0,0 +1,5 @@ +pr: 113873 +summary: Default inference endpoint for ELSER +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/reference/rest-api/usage.asciidoc b/docs/reference/rest-api/usage.asciidoc index 4a8895807f2fa..4dcf0d328e4f1 100644 --- a/docs/reference/rest-api/usage.asciidoc +++ b/docs/reference/rest-api/usage.asciidoc @@ -206,7 +206,12 @@ GET /_xpack/usage "inference": { "available" : true, "enabled" : true, - "models" : [] + "models" : [{ + "service": "elasticsearch", + "task_type": "SPARSE_EMBEDDING", + "count": 1 + } + ] }, "logstash" : { "available" : true, diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index aba644b392cec..cbbfef2cc65fa 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -191,4 +191,13 @@ default Set supportedStreamingTasks() { default boolean canStream(TaskType taskType) { return supportedStreamingTasks().contains(taskType); } + + /** + * A service can define default configurations that can be + * used out of the box without creating an endpoint first. + * @return Default configurations provided by this service + */ + default List defaultConfigs() { + return List.of(); + } } diff --git a/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java b/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java new file mode 100644 index 0000000000000..30a7c6aa2bf9c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/UnparsedModel.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import java.util.Map; + +/** + * Semi parsed model where inference entity id, task type and service + * are known but the settings are not parsed. 
+ */ +public record UnparsedModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map<String, Object> settings, + Map<String, Object> secrets +) {} diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index 7df791bf11559..5adf01a2a0e7d 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -19,7 +19,8 @@ public enum FeatureFlag { TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null), FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null), CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null), - INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null); + INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null), + INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index b90e8a22ea6c6..feddc7bcfba3f 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -201,5 +201,6 @@ tasks.named("precommit").configure { tasks.named("yamlRestTestV7CompatTransform").configure({ task -> task.skipTest("security/10_forbidden/Test bulk response with invalid credentials", "warning does not exist for compatibility") + task.skipTest("inference/inference_crud/Test get all", "Assertions on number of inference models break due to default configs") }) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index 4a570bfde99a4..34ebdcb7f9f9f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -237,7 +237,9 @@ public int computeNumberOfAllocations() { if (numberOfAllocations != null) { return numberOfAllocations; } else { - if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null) { + if (adaptiveAllocationsSettings == null + || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null + || adaptiveAllocationsSettings.getMinNumberOfAllocations() == 0) { return DEFAULT_NUM_ALLOCATIONS; } else { return adaptiveAllocationsSettings.getMinNumberOfAllocations(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java index 19af6a3a4ef4c..d4eace8e96621 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java @@ -147,8 +147,8 @@ public AdaptiveAllocationsSettings
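/*
 * Illustrative sketch (not part of this change set): what a service-side override of the new
 * InferenceService#defaultConfigs() hook could look like, using the UnparsedModel record above.
 * The endpoint id, service name and settings below are hypothetical; only the record shape and
 * the defaultConfigs() contract come from this diff. The returned configs are what
 * InferencePlugin feeds into ModelRegistry#addDefaultConfiguration at start-up.
 */
import java.util.List;
import java.util.Map;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;

final class DefaultConfigSketch {

    // A service's defaultConfigs() override would return something like this.
    static List<UnparsedModel> exampleDefaults() {
        return List.of(
            new UnparsedModel(
                ".example-default",                                                   // hypothetical reserved id
                TaskType.SPARSE_EMBEDDING,
                "example-service",                                                    // hypothetical service name
                Map.<String, Object>of("service_settings", Map.of("num_threads", 1)),
                Map.of()                                                              // defaults carry no secrets here
            )
        );
    }

    private DefaultConfigSketch() {}
}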
merge(AdaptiveAllocationsSettings updates) { public ActionRequestValidationException validate() { ActionRequestValidationException validationException = new ActionRequestValidationException(); boolean hasMinNumberOfAllocations = (minNumberOfAllocations != null && minNumberOfAllocations != -1); - if (hasMinNumberOfAllocations && minNumberOfAllocations < 1) { - validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null"); + if (hasMinNumberOfAllocations && minNumberOfAllocations < 0) { + validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a non-negative integer or null"); } boolean hasMaxNumberOfAllocations = (maxNumberOfAllocations != null && maxNumberOfAllocations != -1); if (hasMaxNumberOfAllocations && maxNumberOfAllocations < 1) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java index c86648f10f08b..d59fbb2a24ee0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java @@ -17,7 +17,7 @@ public class AdaptiveAllocationSettingsTests extends AbstractWireSerializingTest public static AdaptiveAllocationsSettings testInstance() { return new AdaptiveAllocationsSettings( randomBoolean() ? null : randomBoolean(), - randomBoolean() ? null : randomIntBetween(1, 2), + randomBoolean() ? null : randomIntBetween(0, 2), randomBoolean() ? null : randomIntBetween(2, 4) ); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java index 65b7a138e7e1e..c05d08fa33692 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java @@ -85,7 +85,7 @@ public void testSparse() throws IOException { var inferenceId = "sparse-inf"; putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING); - var results = inferOnMockService(inferenceId, List.of("washing", "machine")); + var results = infer(inferenceId, List.of("washing", "machine")); deleteModel(inferenceId); assertNotNull(results.get("sparse_embedding")); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java new file mode 100644 index 0000000000000..5d84aad4b7344 --- /dev/null +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
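/*
 * Compact restatement (for readability, not a second implementation) of the allocation fallback
 * changed above in StartTrainedModelDeploymentAction together with the relaxed validation in
 * AdaptiveAllocationsSettings: an adaptive minimum of null or 0 now falls back to the default
 * allocation count instead of being rejected. DEFAULT_NUM_ALLOCATIONS is assumed to be 1 here
 * purely for illustration; the real constant lives in StartTrainedModelDeploymentAction.
 */
final class AllocationFallbackSketch {

    private static final int DEFAULT_NUM_ALLOCATIONS = 1; // assumed value, for the example only

    static int computeNumberOfAllocations(Integer explicitAllocations, Integer adaptiveMinAllocations) {
        if (explicitAllocations != null) {
            return explicitAllocations;                  // an explicit setting always wins
        }
        if (adaptiveMinAllocations == null || adaptiveMinAllocations == 0) {
            return DEFAULT_NUM_ALLOCATIONS;              // scale-to-zero configs still start with a sane count
        }
        return adaptiveMinAllocations;
    }

    private AllocationFallbackSketch() {}
}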
+ */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class DefaultElserIT extends InferenceBaseRestTest { + + private TestThreadPool threadPool; + + @Before + public void createThreadPool() { + threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName()); + } + + @After + public void tearDown() throws Exception { + threadPool.close(); + super.tearDown(); + } + + @SuppressWarnings("unchecked") + public void testInferCreatesDefaultElser() throws IOException { + assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled()); + var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID); + assertDefaultElserConfig(model); + + var inputs = List.of("Hello World", "Goodnight moon"); + var queryParams = Map.of("timeout", "120s"); + var results = infer(ElasticsearchInternalService.DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, inputs, queryParams); + var embeddings = (List>) results.get("sparse_embedding"); + assertThat(results.toString(), embeddings, hasSize(2)); + } + + @SuppressWarnings("unchecked") + private static void assertDefaultElserConfig(Map modelConfig) { + assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_ELSER_ID, modelConfig.get("inference_id")); + assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service")); + assertEquals(modelConfig.toString(), TaskType.SPARSE_EMBEDDING.toString(), modelConfig.get("task_type")); + + var serviceSettings = (Map) modelConfig.get("service_settings"); + assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(oneOf(".elser_model_2", ".elser_model_2_linux-x86_64"))); + assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads")); + + var adaptiveAllocations = (Map) serviceSettings.get("adaptive_allocations"); + assertThat( + modelConfig.toString(), + adaptiveAllocations, + Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8)) + ); + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 3fa6159661b7e..f82b6f155c0a0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -270,7 +270,7 @@ protected Map deployE5TrainedModels() throws IOException { @SuppressWarnings("unchecked") protected Map getModel(String modelId) throws IOException { - var endpoint = Strings.format("_inference/%s", modelId); + var endpoint = Strings.format("_inference/%s?error_trace", modelId); return ((List>) getInternal(endpoint).get("endpoints")).get(0); } @@ -293,9 +293,9 @@ private Map getInternal(String endpoint) throws IOException { return 
entityAsMap(response); } - protected Map inferOnMockService(String modelId, List input) throws IOException { + protected Map infer(String modelId, List input) throws IOException { var endpoint = Strings.format("_inference/%s", modelId); - return inferOnMockServiceInternal(endpoint, input); + return inferInternal(endpoint, input, Map.of()); } protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { @@ -324,14 +324,23 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } - protected Map inferOnMockService(String modelId, TaskType taskType, List input) throws IOException { + protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); - return inferOnMockServiceInternal(endpoint, input); + return inferInternal(endpoint, input, Map.of()); + } + + protected Map infer(String modelId, TaskType taskType, List input, Map queryParameters) + throws IOException { + var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId); + return inferInternal(endpoint, input, queryParameters); } - private Map inferOnMockServiceInternal(String endpoint, List input) throws IOException { + private Map inferInternal(String endpoint, List input, Map queryParameters) throws IOException { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); + if (queryParameters.isEmpty() == false) { + request.addParameters(queryParameters); + } var response = client().performRequest(request); assertOkOrCreated(response); return entityAsMap(response); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 92affbc043669..5a84fd8985504 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -38,10 +38,12 @@ public void testGet() throws IOException { } var getAllModels = getAllModels(); - assertThat(getAllModels, hasSize(9)); + int numModels = DefaultElserFeatureFlag.isEnabled() ? 10 : 9; + assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); - assertThat(getSparseModels, hasSize(5)); + int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 
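/*
 * The jsonBody(input) helper referenced by inferInternal above is not included in this diff.
 * A minimal stand-in that matches the _inference API request body (an "input" array of strings)
 * might look like the sketch below; this is an assumption for readers, not the repository's
 * actual implementation (which may handle escaping and additional fields).
 */
private static String jsonBody(List<String> input) {
    var quoted = input.stream()
        .map(s -> "\"" + s + "\"")                      // no escaping: test inputs are plain strings
        .collect(java.util.stream.Collectors.joining(", "));
    return "{\"input\": [" + quoted + "]}";
}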
6 : 5; + assertThat(getSparseModels, hasSize(numSparseModels)); for (var sparseModel : getSparseModels) { assertEquals("sparse_embedding", sparseModel.get("task_type")); } @@ -99,7 +101,7 @@ public void testApisWithoutTaskType() throws IOException { assertEquals(modelId, singleModel.get("inference_id")); assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type")); - var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10))); + var inference = infer(modelId, List.of(randomAlphaOfLength(10))); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); deleteModel(modelId); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java index 5f6bad5687407..1077bfec8bbbd 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java @@ -28,15 +28,12 @@ public void testMockService() throws IOException { } List input = List.of(randomAlphaOfLength(10)); - var inference = inferOnMockService(inferenceEntityId, input); + var inference = infer(inferenceEntityId, input); assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING); // Same input should return the same result - assertEquals(inference, inferOnMockService(inferenceEntityId, input)); + assertEquals(inference, infer(inferenceEntityId, input)); // Different input values should not - assertNotEquals( - inference, - inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))) - ); + assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); } public void testMockServiceWithMultipleInputs() throws IOException { @@ -44,7 +41,7 @@ public void testMockServiceWithMultipleInputs() throws IOException { putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING); // The response is randomly generated, the input can be anything - var inference = inferOnMockService( + var inference = infer( inferenceEntityId, TaskType.TEXT_EMBEDDING, List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15)) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java index 24ba2708f5de4..9a17d8edc0768 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java @@ -30,15 +30,12 @@ public void testMockService() throws IOException { } List input = List.of(randomAlphaOfLength(10)); - var inference = inferOnMockService(inferenceEntityId, input); + var inference = infer(inferenceEntityId, input); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); // Same input should return the same result - 
assertEquals(inference, inferOnMockService(inferenceEntityId, input)); + assertEquals(inference, infer(inferenceEntityId, input)); // Different input values should not - assertNotEquals( - inference, - inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))) - ); + assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))); } public void testMockServiceWithMultipleInputs() throws IOException { @@ -46,7 +43,7 @@ public void testMockServiceWithMultipleInputs() throws IOException { putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); // The response is randomly generated, the input can be anything - var inference = inferOnMockService( + var inference = infer( inferenceEntityId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15)) @@ -84,7 +81,7 @@ public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws I } // The response is randomly generated, the input can be anything - var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10))); + var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10))); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); } @@ -102,7 +99,7 @@ public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOEx } // The response is randomly generated, the input can be anything - var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10))); + var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10))); assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING); } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java index 01e8c30e3bf27..8d9c859f129cb 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java @@ -38,7 +38,7 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException { var models = getTrainedModel("_all"); assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId)); - Map results = inferOnMockService( + Map results = infer( inferenceEntityId, TaskType.TEXT_EMBEDDING, List.of("hello world", "this is the second document") @@ -57,7 +57,7 @@ public void testPutE5Small_withPlatformSpecificVariant() throws IOException { var models = getTrainedModel("_all"); assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId)); - Map results = inferOnMockService( + Map results = infer( inferenceEntityId, TaskType.TEXT_EMBEDDING, List.of("hello world", "this is the second document") diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index ea8b32f36f54c..8e68ca9dfa565 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ 
b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.reindex.ReindexPlugin; import org.elasticsearch.test.ESSingleNodeTestCase; @@ -38,6 +39,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -110,7 +112,7 @@ public void testGetModel() throws Exception { assertThat(putModelHolder.get(), is(true)); // now get the model - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); assertThat(modelHolder.get(), not(nullValue())); @@ -168,7 +170,7 @@ public void testDeleteModel() throws Exception { // get should fail deleteResponseHolder.set(false); - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder); assertThat(exceptionHolder.get(), not(nullValue())); @@ -194,7 +196,7 @@ public void testGetModelsByTaskType() throws InterruptedException { } AtomicReference exceptionHolder = new AtomicReference<>(); - AtomicReference> modelHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get(), hasSize(3)); var sparseIds = sparseAndTextEmbeddingModels.stream() @@ -235,8 +237,9 @@ public void testGetAllModels() throws InterruptedException { assertNull(exceptionHolder.get()); } - AtomicReference> modelHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertNull(exceptionHolder.get()); assertThat(modelHolder.get(), hasSize(modelCount)); var getAllModels = modelHolder.get(); @@ -264,15 +267,213 @@ public void testGetModelWithSecrets() throws InterruptedException { assertThat(putModelHolder.get(), is(true)); assertNull(exceptionHolder.get()); - AtomicReference modelHolder = new AtomicReference<>(); + AtomicReference modelHolder = new AtomicReference<>(); blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), hasSize(1)); var secretSettings = (Map) modelHolder.get().secrets().get("secret_settings"); assertThat(secretSettings.get("secret"), equalTo(secret)); + assertReturnModelIsModifiable(modelHolder.get()); // get model without secrets blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder); assertThat(modelHolder.get().secrets().keySet(), empty()); + assertReturnModelIsModifiable(modelHolder.get()); + } + + public void testGetAllModels_WithDefaults() throws Exception { + var service = "foo"; + var secret = "abc"; + int 
configuredModelCount = 10; + int defaultModelCount = 2; + int totalModelCount = 12; + + var defaultConfigs = new HashMap(); + for (int i = 0; i < defaultModelCount; i++) { + var id = "default-" + i; + defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret)); + } + defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var createdModels = new HashMap(); + for (int i = 0; i < configuredModelCount; i++) { + var id = randomAlphaOfLength(5) + i; + var model = createModel(id, randomFrom(TaskType.values()), service); + createdModels.put(id, model); + blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + } + + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertNull(exceptionHolder.get()); + assertThat(modelHolder.get(), hasSize(totalModelCount)); + var getAllModels = modelHolder.get(); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + + // sort in the same order as the returned models + var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList()); + ids.addAll(createdModels.keySet().stream().toList()); + ids.sort(String::compareTo); + for (int i = 0; i < totalModelCount; i++) { + var id = ids.get(i); + assertEquals(id, getAllModels.get(i).inferenceEntityId()); + if (id.startsWith("default")) { + assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType()); + assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service()); + } else { + assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType()); + assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service()); + } + } + } + + public void testGetAllModels_OnlyDefaults() throws Exception { + var service = "foo"; + var secret = "abc"; + int defaultModelCount = 2; + + var defaultConfigs = new HashMap(); + for (int i = 0; i < defaultModelCount; i++) { + var id = "default-" + i; + defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret)); + } + defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration); + + AtomicReference exceptionHolder = new AtomicReference<>(); + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder); + assertNull(exceptionHolder.get()); + assertThat(modelHolder.get(), hasSize(2)); + var getAllModels = modelHolder.get(); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + + // sort in the same order as the returned models + var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList()); + ids.sort(String::compareTo); + for (int i = 0; i < defaultModelCount; i++) { + var id = ids.get(i); + assertEquals(id, getAllModels.get(i).inferenceEntityId()); + assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType()); + assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service()); + } + } + + public void testGet_WithDefaults() throws InterruptedException { + var service = "foo"; + var secret = "abc"; + + var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret); + var 
defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret); + + modelRegistry.addDefaultConfiguration(defaultSparse); + modelRegistry.addDefaultConfiguration(defaultText); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service); + var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service); + blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + + AtomicReference modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder); + assertEquals("default-sparse", modelHolder.get().inferenceEntityId()); + assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType()); + assertReturnModelIsModifiable(modelHolder.get()); + + blockingCall(listener -> modelRegistry.getModel("default-text", listener), modelHolder, exceptionHolder); + assertEquals("default-text", modelHolder.get().inferenceEntityId()); + assertEquals(TaskType.TEXT_EMBEDDING, modelHolder.get().taskType()); + + blockingCall(listener -> modelRegistry.getModel(configured1.getInferenceEntityId(), listener), modelHolder, exceptionHolder); + assertEquals(configured1.getInferenceEntityId(), modelHolder.get().inferenceEntityId()); + assertEquals(configured1.getTaskType(), modelHolder.get().taskType()); + } + + public void testGetByTaskType_WithDefaults() throws Exception { + var service = "foo"; + var secret = "abc"; + + var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret); + var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret); + var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret); + + modelRegistry.addDefaultConfiguration(defaultSparse); + modelRegistry.addDefaultConfiguration(defaultText); + modelRegistry.addDefaultConfiguration(defaultChat); + + AtomicReference putModelHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service); + var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service); + var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service); + blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + blockingCall(listener -> modelRegistry.storeModel(configuredRerank, listener), putModelHolder, exceptionHolder); + assertThat(putModelHolder.get(), is(true)); + assertNull(exceptionHolder.get()); + + AtomicReference> modelHolder = new AtomicReference<>(); + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder); + if (exceptionHolder.get() != null) { + throw exceptionHolder.get(); 
+ } + assertNull(exceptionHolder.get()); + assertThat(modelHolder.get(), hasSize(2)); + assertEquals("configured-sparse", modelHolder.get().get(0).inferenceEntityId()); + assertEquals("default-sparse", modelHolder.get().get(1).inferenceEntityId()); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(2)); + assertEquals("configured-text", modelHolder.get().get(0).inferenceEntityId()); + assertEquals("default-text", modelHolder.get().get(1).inferenceEntityId()); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.RERANK, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(1)); + assertEquals("configured-rerank", modelHolder.get().get(0).inferenceEntityId()); + + blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.COMPLETION, listener), modelHolder, exceptionHolder); + assertThat(modelHolder.get(), hasSize(1)); + assertEquals("default-chat", modelHolder.get().get(0).inferenceEntityId()); + assertReturnModelIsModifiable(modelHolder.get().get(0)); + } + + @SuppressWarnings("unchecked") + private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) { + var settings = unparsedModel.settings(); + if (settings != null) { + var serviceSettings = (Map) settings.get("service_settings"); + if (serviceSettings != null && serviceSettings.size() > 0) { + var itr = serviceSettings.entrySet().iterator(); + itr.next(); + itr.remove(); + } + + var taskSettings = (Map) settings.get("task_settings"); + if (taskSettings != null && taskSettings.size() > 0) { + var itr = taskSettings.entrySet().iterator(); + itr.next(); + itr.remove(); + } + + if (unparsedModel.secrets() != null && unparsedModel.secrets().size() > 0) { + var itr = unparsedModel.secrets().entrySet().iterator(); + itr.next(); + itr.remove(); + } + } } private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { @@ -327,6 +528,10 @@ public static Model createModelWithSecrets(String inferenceEntityId, TaskType ta ); } + public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) { + return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret)); + } + private static class TestModelOfAnyKind extends ModelConfigurations { record TestModelServiceSettings() implements ServiceSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java new file mode 100644 index 0000000000000..2a764dabd62ae --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
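/*
 * The blockingCall(...) helper used throughout these registry tests is not shown in this diff.
 * Its shape is constrained by the call sites above; a typical latch-based implementation
 * (an assumption, included only to make the test pattern easier to follow) would be:
 */
private <T> void blockingCall(
    java.util.function.Consumer<ActionListener<T>> function,
    AtomicReference<T> response,
    AtomicReference<Exception> error
) throws InterruptedException {
    var latch = new java.util.concurrent.CountDownLatch(1);
    ActionListener<T> listener = ActionListener.wrap(r -> {
        response.set(r);
        latch.countDown();
    }, e -> {
        error.set(e);
        latch.countDown();
    });
    function.accept(listener);
    latch.await();
}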
+ */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +public class DefaultElserFeatureFlag { + + private DefaultElserFeatureFlag() {} + + private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_default_elser"); + + public static boolean isEnabled() { + return FEATURE_FLAG.isEnabled(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 0ab395f4bfa39..dbb9130ab91e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -210,6 +210,9 @@ public Collection createComponents(PluginServices services) { // reference correctly var registry = new InferenceServiceRegistry(inferenceServices, factoryContext); registry.init(services.client()); + for (var service : registry.getServices().values()) { + service.defaultConfigs().forEach(modelRegistry::addDefaultConfiguration); + } inferenceServiceRegistry.set(registry); var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 3c893f8870627..829a6b6c67ff9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -24,6 +24,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -91,7 +92,7 @@ private void doExecuteForked( ClusterState state, ActionListener masterListener ) { - SubscribableListener.newForked(modelConfigListener -> { + SubscribableListener.newForked(modelConfigListener -> { // Get the model from the registry modelRegistry.getModel(request.getInferenceEndpointId(), modelConfigListener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index a1f33afa05b5c..5ee1e40869dbc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -112,7 +113,7 @@ private void getModelsByTaskType(TaskType taskType, ActionListener unparsedModels) { + 
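/*
 * Expanded form of the InferencePlugin wiring above, spelling out the constraint it relies on:
 * default inference ids must be unique across services, because addDefaultConfiguration
 * (added to ModelRegistry further down) throws IllegalStateException on a duplicate id rather
 * than letting one service shadow another. The id in the comment is illustrative.
 */
for (InferenceService service : registry.getServices().values()) {
    for (UnparsedModel defaultConfig : service.defaultConfigs()) {
        // e.g. ".elser-2" from ElasticsearchInternalService; a second service reusing that id
        // would fail fast at node start-up here.
        modelRegistry.addDefaultConfiguration(defaultConfig);
    }
}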
private GetInferenceModelAction.Response parseModels(List unparsedModels) { var parsedModels = new ArrayList(); for (var unparsedModel : unparsedModels) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index d2a73b7df77c1..4045734546596 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -64,30 +65,16 @@ public TransportInferenceAction( @Override protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { + ActionListener getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> { var service = serviceRegistry.getService(unparsedModel.service()); if (service.isEmpty()) { - delegate.onFailure( - new ElasticsearchStatusException( - "Unknown service [{}] for model [{}]. ", - RestStatus.INTERNAL_SERVER_ERROR, - unparsedModel.service(), - unparsedModel.inferenceEntityId() - ) - ); + listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); return; } if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { // not the wildcard task type and not the model task type - delegate.onFailure( - new ElasticsearchStatusException( - "Incompatible task_type, the requested type [{}] does not match the model type [{}]", - RestStatus.BAD_REQUEST, - request.getTaskType(), - unparsedModel.taskType() - ) - ); + listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType())); return; } @@ -98,7 +85,6 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe unparsedModel.settings(), unparsedModel.secrets() ); - inferenceStats.incrementRequestCount(model); inferOnService(model, request, service.get(), delegate); }); @@ -112,6 +98,7 @@ private void inferOnService( ActionListener listener ) { if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + inferenceStats.incrementRequestCount(model); service.infer( model, request.getQuery(), @@ -160,5 +147,19 @@ private ActionListener createListener( }); } return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults))); - }; + } + + private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { + return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. 
", RestStatus.BAD_REQUEST, service, inferenceId); + } + + private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { + return new ElasticsearchStatusException( + "Incompatible task_type, the requested type [{}] does not match the model type [{}]", + RestStatus.BAD_REQUEST, + requested, + expected + ); + } + } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index ade0748ef10bf..a4eb94c2674d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -35,6 +35,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; @@ -211,9 +212,9 @@ private void executeShardBulkInferenceAsync( final Releasable onFinish ) { if (inferenceProvider == null) { - ActionListener modelLoadingListener = new ActionListener<>() { + ActionListener modelLoadingListener = new ActionListener<>() { @Override - public void onResponse(ModelRegistry.UnparsedModel unparsedModel) { + public void onResponse(UnparsedModel unparsedModel) { var service = inferenceServiceRegistry.getService(unparsedModel.service()); if (service.isEmpty() == false) { var provider = new InferenceProvider( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index a6e4fcae7169f..d756c0ef26f14 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -32,6 +32,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; @@ -48,6 +49,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Function; @@ -58,32 +61,19 @@ public class ModelRegistry { public record ModelConfigMap(Map config, Map secrets) {} - /** - * Semi parsed model where inference entity id, task type and service - * are known but the settings are not parsed. 
- */ - public record UnparsedModel( - String inferenceEntityId, - TaskType taskType, - String service, - Map settings, - Map secrets - ) { - - public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { - if (modelConfigMap.config() == null) { - throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); - } - String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull( - modelConfigMap.config(), - ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME - ); - String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); - String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); - TaskType taskType = TaskType.fromString(taskTypeStr); - - return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); + public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) { + if (modelConfigMap.config() == null) { + throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST); } + String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull( + modelConfigMap.config(), + ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME + ); + String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE); + String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME); + TaskType taskType = TaskType.fromString(taskTypeStr); + + return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets()); } private static final String TASK_TYPE_FIELD = "task_type"; @@ -91,9 +81,27 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private static final Logger logger = LogManager.getLogger(ModelRegistry.class); private final OriginSettingClient client; + private Map defaultConfigs; public ModelRegistry(Client client) { this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); + this.defaultConfigs = new HashMap<>(); + } + + public void addDefaultConfiguration(UnparsedModel serviceDefaultConfig) { + if (defaultConfigs.containsKey(serviceDefaultConfig.inferenceEntityId())) { + throw new IllegalStateException( + "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [" + + serviceDefaultConfig.inferenceEntityId() + + "] declared by service [" + + serviceDefaultConfig.service() + + "]. The inference Id is already use by [" + + defaultConfigs.get(serviceDefaultConfig.inferenceEntityId()).service() + + "] service." 
+ ); + } + + defaultConfigs.put(serviceDefaultConfig.inferenceEntityId(), serviceDefaultConfig); } /** @@ -102,6 +110,11 @@ public ModelRegistry(Client client) { * @param listener Model listener */ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + if (defaultConfigs.containsKey(inferenceEntityId)) { + listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId))); + return; + } + ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets if (searchResponse.getHits().getHits().length == 0) { @@ -109,7 +122,7 @@ public void getModelWithSecrets(String inferenceEntityId, ActionListener listener) { + if (defaultConfigs.containsKey(inferenceEntityId)) { + listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId))); + return; + } + ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { // There should be a hit for the configurations and secrets if (searchResponse.getHits().getHits().length == 0) { @@ -135,7 +153,7 @@ public void getModel(String inferenceEntityId, ActionListener lis return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); assert modelConfigs.size() == 1; delegate.onResponse(modelConfigs.get(0)); }); @@ -162,14 +180,29 @@ private ResourceNotFoundException inferenceNotFoundException(String inferenceEnt */ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { + var defaultConfigsForTaskType = defaultConfigs.values() + .stream() + .filter(m -> m.taskType() == taskType) + .map(ModelRegistry::deepCopyDefaultConfig) + .toList(); + // Not an error if no models of this task_type - if (searchResponse.getHits().getHits().length == 0) { + if (searchResponse.getHits().getHits().length == 0 && defaultConfigsForTaskType.isEmpty()) { delegate.onResponse(List.of()); return; } - var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); - delegate.onResponse(modelConfigs); + var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); + + if (defaultConfigsForTaskType.isEmpty() == false) { + var allConfigs = new ArrayList(); + allConfigs.addAll(modelConfigs); + allConfigs.addAll(defaultConfigsForTaskType); + allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + delegate.onResponse(allConfigs); + } else { + delegate.onResponse(modelConfigs); + } }); QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString())); @@ -191,14 +224,19 @@ public void getModelsByTaskType(TaskType taskType, ActionListener> listener) { ActionListener searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> { - // Not an error if no models of this task_type - if (searchResponse.getHits().getHits().length == 0) { + var defaults = defaultConfigs.values().stream().map(ModelRegistry::deepCopyDefaultConfig).toList(); + + if (searchResponse.getHits().getHits().length == 0 && defaults.isEmpty()) { delegate.onResponse(List.of()); return; } - var modelConfigs = 
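/*
 * Worked example of the ordering contract implemented above and asserted in ModelRegistryIT:
 * persisted endpoints and in-memory defaults are concatenated and sorted by inference id, so
 * callers always see one stable, lexicographically ordered list. The ids are illustrative.
 */
static List<String> mergedIdsExample() {
    var persisted = List.of("configured-sparse", "zz-custom");
    var defaults = List.of(".elser-2", "default-sparse");
    var all = new ArrayList<>(persisted);
    all.addAll(defaults);
    all.sort(String::compareTo);
    return all; // [".elser-2", "configured-sparse", "default-sparse", "zz-custom"]
}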
parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList(); - delegate.onResponse(modelConfigs); + var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList(); + var allConfigs = new ArrayList(); + allConfigs.addAll(foundConfigs); + allConfigs.addAll(defaults); + allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId)); + delegate.onResponse(allConfigs); }); // In theory the index should only contain model config documents @@ -216,7 +254,7 @@ public void getAllModels(ActionListener> listener) { client.search(modelSearch, searchListener); } - private List parseHitsAsModels(SearchHits hits) { + private ArrayList parseHitsAsModels(SearchHits hits) { var modelConfigs = new ArrayList(); for (var hit : hits) { modelConfigs.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of())); @@ -393,4 +431,57 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T private QueryBuilder documentIdQuery(String inferenceEntityId) { return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId))); } + + static UnparsedModel deepCopyDefaultConfig(UnparsedModel other) { + // Because the default config uses immutable maps + return new UnparsedModel( + other.inferenceEntityId(), + other.taskType(), + other.service(), + copySettingsMap(other.settings()), + copySecretsMap(other.secrets()) + ); + } + + @SuppressWarnings("unchecked") + static Map copySettingsMap(Map other) { + var result = new HashMap(); + + var serviceSettings = (Map) other.get(ModelConfigurations.SERVICE_SETTINGS); + if (serviceSettings != null) { + var copiedServiceSettings = copyMap1LevelDeep(serviceSettings); + result.put(ModelConfigurations.SERVICE_SETTINGS, copiedServiceSettings); + } + + var taskSettings = (Map) other.get(ModelConfigurations.TASK_SETTINGS); + if (taskSettings != null) { + var copiedTaskSettings = copyMap1LevelDeep(taskSettings); + result.put(ModelConfigurations.TASK_SETTINGS, copiedTaskSettings); + } + + var chunkSettings = (Map) other.get(ModelConfigurations.CHUNKING_SETTINGS); + if (chunkSettings != null) { + var copiedChunkSettings = copyMap1LevelDeep(chunkSettings); + result.put(ModelConfigurations.CHUNKING_SETTINGS, copiedChunkSettings); + } + + return result; + } + + static Map copySecretsMap(Map other) { + return copyMap1LevelDeep(other); + } + + @SuppressWarnings("unchecked") + static Map copyMap1LevelDeep(Map other) { + var result = new HashMap(); + for (var entry : other.entrySet()) { + if (entry.getValue() instanceof Map) { + result.put(entry.getKey(), new HashMap<>((Map) entry.getValue())); + } else { + result.put(entry.getKey(), entry.getValue()); + } + } + return result; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 23e806e01300a..0dd41db2f016c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; 
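/*
 * Why the deep copy above matters, shown as a tiny demo: default configurations are built from
 * Map.of(...), which is immutable, while the parsing path consumes maps destructively (see
 * removeStringOrThrowIfNull in unparsedModelFromMap). Handing each caller a one-level-deep copy
 * keeps the shared default intact.
 */
static void immutableDefaultDemo() {
    Map<String, Object> shared = Map.of("service_settings", Map.of("num_threads", 1));

    // shared.remove("service_settings");              // would throw UnsupportedOperationException

    Map<String, Object> copy = new HashMap<>(shared);  // what copyMap1LevelDeep does at each level
    copy.remove("service_settings");                   // fine: only this caller's copy changes
}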
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.SubscribableListener;
@@ -31,6 +32,7 @@
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
+import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 
 import java.io.IOException;
@@ -80,7 +82,6 @@ public BaseElasticsearchInternalService(
     @Override
     public void start(Model model, ActionListener<Boolean> finalListener) {
         if (model instanceof ElasticsearchInternalModel esModel) {
-
             if (supportedTaskTypes().contains(model.getTaskType()) == false) {
                 finalListener.onFailure(
                     new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()))
@@ -149,7 +150,7 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
         }
     }
 
-    private void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
+    protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
         var input = new TrainedModelInput(List.of("text_field")); // by convention text_field is used
         var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
         PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
@@ -258,4 +259,27 @@ public static InferModelAction.Request buildInferenceRequest(
         request.setChunked(chunk);
         return request;
     }
+
+    protected abstract boolean isDefaultId(String inferenceId);
+
+    protected void maybeStartDeployment(
+        ElasticsearchInternalModel model,
+        Exception e,
+        InferModelAction.Request request,
+        ActionListener<InferModelAction.Response> listener
+    ) {
+        if (DefaultElserFeatureFlag.isEnabled() == false) {
+            listener.onFailure(e);
+            return;
+        }
+
+        if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+            this.start(
+                model,
+                listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })
+            );
+        } else {
+            listener.onFailure(e);
+        }
+    }
 }
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
index e274c641e30be..dd14e16412996 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
@@ -26,6 +26,7 @@
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
@@ -73,6 +74,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
     );
 
+    public static final String DEFAULT_ELSER_ID = ".elser-2";
+
     private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
     private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
 
@@ -100,6 +103,17 @@ public void parseRequestConfig(
         Map<String, Object> config,
         ActionListener<Model> modelListener
     ) {
+        if (inferenceEntityId.equals(DEFAULT_ELSER_ID)) {
+            modelListener.onFailure(
+                new ElasticsearchStatusException(
+                    "[{}] is a reserved inference Id. Cannot create a new inference endpoint with a reserved Id",
+                    RestStatus.BAD_REQUEST,
+                    inferenceEntityId
+                )
+            );
+            return;
+        }
+
         try {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
             Map<String, Object> taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS);
@@ -459,20 +473,24 @@ public void infer(
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        var taskType = model.getConfigurations().getTaskType();
-        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
-            inferTextEmbedding(model, input, inputType, timeout, listener);
-        } else if (TaskType.RERANK.equals(taskType)) {
-            inferRerank(model, query, input, inputType, timeout, taskSettings, listener);
-        } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
-            inferSparseEmbedding(model, input, inputType, timeout, listener);
+        if (model instanceof ElasticsearchInternalModel esModel) {
+            var taskType = model.getConfigurations().getTaskType();
+            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                inferTextEmbedding(esModel, input, inputType, timeout, listener);
+            } else if (TaskType.RERANK.equals(taskType)) {
+                inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener);
+            } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+                inferSparseEmbedding(esModel, input, inputType, timeout, listener);
+            } else {
+                throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
+            }
         } else {
-            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
+            listener.onFailure(notElasticsearchModelException(model));
         }
     }
 
     public void inferTextEmbedding(
-        Model model,
+        ElasticsearchInternalModel model,
         List<String> inputs,
         InputType inputType,
         TimeValue timeout,
@@ -487,17 +505,19 @@ public void inferTextEmbedding(
             false
         );
 
-        client.execute(
-            InferModelAction.INSTANCE,
-            request,
-            listener.delegateFailureAndWrap(
-                (l, inferenceResult) -> l.onResponse(InferenceTextEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
-            )
+        ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
+            (l, inferenceResult) -> l.onResponse(InferenceTextEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
+        );
+
+        var maybeDeployListener = mlResultsListener.delegateResponse(
+            (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
         );
+
+        client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
     }
 
     public void inferSparseEmbedding(
-        Model model,
+        ElasticsearchInternalModel model,
         List<String> inputs,
         InputType inputType,
         TimeValue timeout,
@@ -512,17 +532,19 @@ public void inferSparseEmbedding(
             false
         );
 
-        client.execute(
-            InferModelAction.INSTANCE,
-            request,
-            listener.delegateFailureAndWrap(
-                (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults()))
-            )
+        ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
+            (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults()))
+        );
+
+        var maybeDeployListener = mlResultsListener.delegateResponse(
+            (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
        );
+
+        client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
     }
 
     public void inferRerank(
-        Model model,
+        ElasticsearchInternalModel model,
         String query,
         List<String> inputs,
         InputType inputType,
@@ -671,4 +693,42 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
         Collections.sort(rankings);
         return new RankedDocsResults(rankings);
     }
+
+    @Override
+    public List<UnparsedModel> defaultConfigs() {
+        // TODO Chunking settings
+        Map<String, Object> elserSettings = Map.of(
+            ModelConfigurations.SERVICE_SETTINGS,
+            Map.of(
+                ElasticsearchInternalServiceSettings.MODEL_ID,
+                ElserModels.ELSER_V2_MODEL, // TODO pick model depending on platform
+                ElasticsearchInternalServiceSettings.NUM_THREADS,
+                1,
+                ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS,
+                Map.of(
+                    "enabled",
+                    Boolean.TRUE,
+                    "min_number_of_allocations",
+                    1,
+                    "max_number_of_allocations",
+                    8 // no max?
+                )
+            )
+        );
+
+        return List.of(
+            new UnparsedModel(
+                DEFAULT_ELSER_ID,
+                TaskType.SPARSE_EMBEDDING,
+                NAME,
+                elserSettings,
+                Map.of() // no secrets
+            )
+        );
+    }
+
+    @Override
+    protected boolean isDefaultId(String inferenceId) {
+        return DEFAULT_ELSER_ID.equals(inferenceId);
+    }
 }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
index d78ea7933e836..770e6e3cb9cf4 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
@@ -26,6 +26,7 @@
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
@@ -266,12 +267,11 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool
         ModelRegistry modelRegistry = mock(ModelRegistry.class);
         Answer unparsedModelAnswer = invocationOnMock -> {
             String id = (String) invocationOnMock.getArguments()[0];
-            ActionListener<ModelRegistry.UnparsedModel> listener = (ActionListener<ModelRegistry.UnparsedModel>) invocationOnMock
-                .getArguments()[1];
+            ActionListener<UnparsedModel> listener = (ActionListener<UnparsedModel>) invocationOnMock.getArguments()[1];
             var model = modelMap.get(id);
             if (model != null) {
                 listener.onResponse(
-                    new ModelRegistry.UnparsedModel(
+                    new UnparsedModel(
                         model.getInferenceEntityId(),
                         model.getTaskType(),
                         model.getServiceSettings().model(),
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
index fbd8ccd621559..75c370fd4d3fb 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
@@ -23,6 +23,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchResponseUtils;
@@ -38,9 +39,12 @@
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.sameInstance;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
@@ -68,7 +72,7 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned()
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT));
@@ -82,7 +86,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalArgumentException_WhenInvalidIn
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT));
@@ -99,7 +103,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
@@ -116,7 +120,7 @@ public void testGetUnparsedModelMap_ThrowsIllegalStateException_WhenUnableToFind
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
@@ -150,7 +154,7 @@ public void testGetModelWithSecrets() {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         var modelConfig = listener.actionGet(TIMEOUT);
@@ -179,7 +183,7 @@ public void testGetModelNoSecrets() {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModel("1", listener);
 
         registry.getModel("1", listener);
@@ -288,6 +292,80 @@ public void testStoreModel_ThrowsException_WhenFailureIsNotAVersionConflict() {
         );
     }
 
+    @SuppressWarnings("unchecked")
+    public void testDeepCopyDefaultConfig() {
+        {
+            var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", Map.of(), Map.of());
+            var copied = ModelRegistry.deepCopyDefaultConfig(toCopy);
+            assertThat(copied, not(sameInstance(toCopy)));
+            assertThat(copied.taskType(), is(toCopy.taskType()));
+            assertThat(copied.service(), is(toCopy.service()));
+            assertThat(copied.secrets(), not(sameInstance(toCopy.secrets())));
+            assertThat(copied.secrets(), is(toCopy.secrets()));
+            // Test copied is a modifiable map
+            copied.secrets().put("foo", "bar");
+
+            assertThat(copied.settings(), not(sameInstance(toCopy.settings())));
+            assertThat(copied.settings(), is(toCopy.settings()));
+            // Test copied is a modifiable map
+            copied.settings().put("foo", "bar");
+        }
+
+        {
+            Map<String, Object> secretsMap = Map.of("secret", "value");
+            Map<String, Object> chunking = Map.of("strategy", "word");
+            Map<String, Object> task = Map.of("user", "name");
+            Map<String, Object> service = Map.of("num_threads", 1, "adaptive_allocations", Map.of("enabled", true));
+            Map<String, Object> settings = Map.of("chunking_settings", chunking, "service_settings", service, "task_settings", task);
+
+            var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", settings, secretsMap);
+            var copied = ModelRegistry.deepCopyDefaultConfig(toCopy);
+            assertThat(copied, not(sameInstance(toCopy)));
+
+            assertThat(copied.secrets(), not(sameInstance(toCopy.secrets())));
+            assertThat(copied.secrets(), is(toCopy.secrets()));
+            // Test copied is a modifiable map
+            copied.secrets().remove("secret");
+
+            assertThat(copied.settings(), not(sameInstance(toCopy.settings())));
+            assertThat(copied.settings(), is(toCopy.settings()));
+            // Test copied is a modifiable map
+            var chunkOut = (Map<String, Object>) copied.settings().get("chunking_settings");
+            assertThat(chunkOut, is(chunking));
+            chunkOut.remove("strategy");
+
+            var taskOut = (Map<String, Object>) copied.settings().get("task_settings");
+            assertThat(taskOut, is(task));
+            taskOut.remove("user");
+
+            var serviceOut = (Map<String, Object>) copied.settings().get("service_settings");
+            assertThat(serviceOut, is(service));
+            var adaptiveOut = (Map<String, Object>) serviceOut.remove("adaptive_allocations");
+            assertThat(adaptiveOut, is(Map.of("enabled", true)));
+            adaptiveOut.remove("enabled");
+        }
+    }
+
+    public void testDuplicateDefaultIds() {
+        var client = mockBulkClient();
+        var registry = new ModelRegistry(client);
+
+        var id = "my-inference";
+
+        registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-a", Map.of(), Map.of()));
+        var ise = expectThrows(
+            IllegalStateException.class,
+            () -> registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-b", Map.of(), Map.of()))
+        );
+        assertThat(
+            ise.getMessage(),
+            containsString(
+                "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [my-inference] declared by "
+                    + "service [service-b]. The inference Id is already use by [service-a] service."
+            )
+        );
+    }
+
     private Client mockBulkClient() {
         var client = mockClient();
         when(client.prepareBulk()).thenReturn(new BulkRequestBuilder(client));
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml
index 6aec721b35418..11be68cc764e2 100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml
@@ -44,15 +44,18 @@
   - do:
       inference.get:
         inference_id: "*"
-  - length: { endpoints: 0}
+  - length: { endpoints: 1}
+  - match: { endpoints.0.inference_id: ".elser-2" }
 
   - do:
       inference.get:
        inference_id: _all
-  - length: { endpoints: 0}
+  - length: { endpoints: 1}
+  - match: { endpoints.0.inference_id: ".elser-2" }
 
   - do:
       inference.get:
        inference_id: ""
-  - length: { endpoints: 0}
+  - length: { endpoints: 1}
+  - match: { endpoints.0.inference_id: ".elser-2" }
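Usage sketch (illustration only, not part of the diff): with the feature flag enabled, the default endpoint should be visible via GET and usable for inference without creating it first. The inference.inference call, the input body field, and the asserted response fields below are assumptions based on the existing inference APIs rather than anything added here; the first sparse_embedding request against .elser-2 is expected to go through the maybeStartDeployment path, so a real test would need to tolerate model download and allocation start-up time.

  - do:
      inference.get:
        inference_id: ".elser-2"
  - match: { endpoints.0.inference_id: ".elser-2" }
  - match: { endpoints.0.task_type: "sparse_embedding" }

  # Assumed request/response shape; the first call may start the ELSER deployment.
  - do:
      inference.inference:
        task_type: sparse_embedding
        inference_id: ".elser-2"
        body:
          input: "important text"
  - is_true: sparse_embedding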