diff --git a/docs/changelog/116352.yaml b/docs/changelog/116352.yaml new file mode 100644 index 0000000000000..1b60f7b5c31c1 --- /dev/null +++ b/docs/changelog/116352.yaml @@ -0,0 +1,5 @@ +pr: 116352 +summary: Add endpoint creation validation for `ElasticsearchInternalService` +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index c6e09f61befa0..59fc694f72348 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -147,10 +147,10 @@ void chunkedInfer( /** * Stop the model deployment. * The default action does nothing except acknowledge the request (true). - * @param unparsedModel The unparsed model configuration + * @param model The model configuration * @param listener The listener */ - default void stop(UnparsedModel unparsedModel, ActionListener listener) { + default void stop(Model model, ActionListener listener) { listener.onResponse(true); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index c1dbd8cfec9d5..4e80d01f85543 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -115,7 +115,9 @@ private void doExecuteForked( var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { - service.get().stop(unparsedModel, listener); + var model = service.get() + .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + service.get().stop(model, listener); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index fc070965f29c2..1520a0c447b7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -22,7 +22,6 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; @@ -118,9 +117,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL } @Override - public void stop(UnparsedModel unparsedModel, ActionListener listener) { - - var model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()); + public void stop(Model model, ActionListener listener) { if (model instanceof ElasticsearchInternalModel esModel) { var serviceSettings = esModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index fe83acc8574aa..2fa11f98adf49 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; import java.util.ArrayList; import java.util.Collections; @@ -498,47 +499,40 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M @Override public void checkModelConfig(Model model, ActionListener listener) { - if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) { - // At this point the inference endpoint configuration has not been persisted yet, if we attempt to do inference using the - // inference id we'll get an error because the trained model code needs to use the persisted inference endpoint to retrieve the - // model id. To get around this we'll have the getEmbeddingSize() method use the model id instead of inference id. So we need - // to create a temporary model that overrides the inference id with the model id. - var temporaryModelWithModelId = new CustomElandEmbeddingModel( - elandModel.getServiceSettings().modelId(), - elandModel.getTaskType(), - elandModel.getConfigurations().getService(), - elandModel.getServiceSettings(), - elandModel.getConfigurations().getChunkingSettings() - ); - - ServiceUtils.getEmbeddingSize( - temporaryModelWithModelId, - this, - listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(elandModel, size))) - ); - } else { - listener.onResponse(model); - } + ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener); } - private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) { - CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings( - model.getServiceSettings().getNumAllocations(), - model.getServiceSettings().getNumThreads(), - model.getServiceSettings().modelId(), - model.getServiceSettings().getAdaptiveAllocationsSettings(), - embeddingSize, - model.getServiceSettings().similarity(), - model.getServiceSettings().elementType() - ); + @Override + public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + if (model instanceof ElasticsearchInternalModel) { + if (model instanceof CustomElandEmbeddingModel embeddingsModel) { + var serviceSettings = embeddingsModel.getServiceSettings(); + + var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + serviceSettings.getNumAllocations(), + serviceSettings.getNumThreads(), + serviceSettings.modelId(), + serviceSettings.getAdaptiveAllocationsSettings(), + embeddingSize, + serviceSettings.similarity(), + serviceSettings.elementType() + ); - return new CustomElandEmbeddingModel( - model.getInferenceEntityId(), - model.getTaskType(), - model.getConfigurations().getService(), - serviceSettings, - model.getConfigurations().getChunkingSettings() - ); + return new CustomElandEmbeddingModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + updatedServiceSettings, + model.getConfigurations().getChunkingSettings() + ); + } else { + // TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do + // need to update the dimensions? + return model; + } + } else { + throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass()); + } } @Override @@ -882,7 +876,10 @@ private List defaultConfigs(boolean useLinuxOptimizedModel) { @Override boolean isDefaultId(String inferenceId) { - return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId); + // return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId); + // TODO: This is a temporary override to ensure that we always deploy models on infer to run a validation call. + // Figure out if this is what we actually want to do? + return true; } static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java new file mode 100644 index 0000000000000..8fefb7e8f3acc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.validation; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.Model; + +public class ElasticsearchInternalServiceModelValidator implements ModelValidator { + + ModelValidator modelValidator; + + public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) { + this.modelValidator = modelValidator; + } + + @Override + public void validate(InferenceService service, Model model, ActionListener listener) { + modelValidator.validate(service, model, listener.delegateResponse((l, exception) -> { + // TODO: Cleanup the below code + service.stop(model, ActionListener.wrap((v) -> listener.onFailure(exception), (e) -> listener.onFailure(exception))); + })); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java index b5bf77cbb3c7d..42815dc1cd806 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java @@ -11,6 +11,15 @@ import org.elasticsearch.inference.TaskType; public class ModelValidatorBuilder { + public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) { + var modelValidator = buildModelValidator(taskType); + if (isElasticsearchInternalService) { + return new ElasticsearchInternalServiceModelValidator(modelValidator); + } else { + return modelValidator; + } + } + public static ModelValidator buildModelValidator(TaskType taskType) { if (taskType == null) { throw new IllegalArgumentException("Task type can't be null"); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 9a4d0dda82238..9f764cc7a97aa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -68,11 +68,13 @@ import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; @@ -1440,7 +1442,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { ); var request = (InferModelAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getId(), is("custom-model")); + assertThat(request.getId(), is(randomInferenceEntityId)); return Void.TYPE; }).when(client).execute(eq(InferModelAction.INSTANCE), any(), any()); when(client.threadPool()).thenReturn(threadPool); @@ -1488,6 +1490,84 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { assertThat(model, is(expectedModel)); } + public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException { + var client = mock(Client.class); + try (var service = createService(client)) { + var model = OpenAiChatCompletionModelTests.createChatCompletionModel( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10) + ); + assertThrows( + ElasticsearchStatusException.class, + () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); } + ); + } + } + + public void testUpdateModelWithEmbeddingDetails_NonElandModelProvided() throws IOException { + var client = mock(Client.class); + try (var service = createService(client)) { + var originalModel = new MultilingualE5SmallModel( + randomAlphaOfLength(10), + TaskType.TEXT_EMBEDDING, + randomAlphaOfLength(10), + new MultilingualE5SmallInternalServiceSettings( + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + null + ), + null + ); + + var updatedModel = service.updateModelWithEmbeddingDetails(originalModel, randomNonNegativeInt()); + assertEquals(originalModel, updatedModel); + } + } + + public void testUpdateModelWithEmbeddingDetails_ElandModelProvided() throws IOException { + var client = mock(Client.class); + try (var service = createService(client)) { + var originalServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + randomNonNegativeInt(), + randomNonNegativeInt(), + randomAlphaOfLength(10), + null + ); + var originalModel = new CustomElandEmbeddingModel( + randomAlphaOfLength(10), + TaskType.TEXT_EMBEDDING, + randomAlphaOfLength(10), + originalServiceSettings, + ChunkingSettingsTests.createRandomChunkingSettings() + ); + + var embeddingSize = randomNonNegativeInt(); + var expectedUpdatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings( + originalServiceSettings.getNumAllocations(), + originalServiceSettings.getNumThreads(), + originalServiceSettings.modelId(), + originalServiceSettings.getAdaptiveAllocationsSettings(), + embeddingSize, + originalServiceSettings.similarity(), + originalServiceSettings.elementType() + ); + var expectedUpdatedModel = new CustomElandEmbeddingModel( + originalModel.getInferenceEntityId(), + originalModel.getTaskType(), + originalModel.getConfigurations().getService(), + expectedUpdatedServiceSettings, + originalModel.getConfigurations().getChunkingSettings() + ); + + var actualUpdatedModel = service.updateModelWithEmbeddingDetails(originalModel, embeddingSize); + assertEquals(expectedUpdatedModel, actualUpdatedModel); + } + } + public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() { { assertFalse(