diff --git a/docs/changelog/113216.yaml b/docs/changelog/113216.yaml new file mode 100644 index 0000000000000..dec0b991fdacf --- /dev/null +++ b/docs/changelog/113216.yaml @@ -0,0 +1,10 @@ +pr: 113216 +summary: "[Inference API] Deprecate elser service" +area: Machine Learning +type: deprecation +issues: [] +deprecation: + title: "[Inference API] Deprecate elser service" + area: REST API + details: The `elser` service of the inference API will be removed in an upcoming release. Please use the elasticsearch service instead. + impact: In the current version there is no impact. In a future version, users of the `elser` service will no longer be able to use it, and will be required to use the `elasticsearch` service to access elser through the inference API. diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index 40b4e37f36509..f1ce94173a550 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -46,7 +46,13 @@ public Map getServices() { } public Optional getService(String serviceName) { - return Optional.ofNullable(services.get(serviceName)); + + if ("elser".equals(serviceName)) { // ElserService.NAME before removal + // here we are aliasing the elser service to use the elasticsearch service instead + return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME + } else { + return Optional.ofNullable(services.get(serviceName)); + } } public List getNamedWriteables() { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 524cd5014c19e..ea8b32f36f54c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -28,11 +28,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceTests; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; import org.junit.Before; import java.io.IOException; @@ -118,10 +117,10 @@ public void testGetModel() throws Exception { assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElserInternalService( + var elserService = new ElasticsearchInternalService( new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class)) ); - ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets( + ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( modelHolder.get().inferenceEntityId(), modelHolder.get().taskType(), modelHolder.get().settings(), @@ -277,7 +276,17 @@ public void testGetModelWithSecrets() throws InterruptedException { } private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { - return ElserInternalServiceTests.randomModelConfig(inferenceEntityId, taskType); + return switch (taskType) { + case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel( + inferenceEntityId, + taskType, + ElasticsearchInternalService.NAME, + ElserInternalServiceSettingsTests.createRandom(), + ElserMlNodeTaskSettingsTests.createRandom() + ); + default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); + }; + } protected void blockingCall(Consumer> function, AtomicReference response, AtomicReference error) @@ -300,7 +309,7 @@ private static Model buildModelWithUnknownField(String inferenceEntityId) { new ModelWithUnknownField( inferenceEntityId, TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, + ElasticsearchInternalService.NAME, ElserInternalServiceSettingsTests.createRandom(), ElserMlNodeTaskSettingsTests.createRandom() ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 336626cd1db20..02bddb6076d69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -64,9 +64,9 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index f2f019490444e..0ab395f4bfa39 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -86,7 +86,6 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; @@ -229,7 +228,6 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - ElserInternalService::new, context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 4186b281a35b5..d2a73b7df77c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.inference.InferenceService; @@ -42,6 +43,7 @@ public class TransportInferenceAction extends HandledTransportAction serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS); + String serviceName = (String) config.remove(ModelConfigurations.SERVICE); // required for elser service in elasticsearch service throwIfNotEmptyMap(config, name()); String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID); if (modelId == null) { - throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); - } - if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { + if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) { + // TODO complete deprecation of null model ID + // throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); + DEPRECATION_LOGGER.critical( + DeprecationCategory.API, + "inference_api_null_model_id_in_elasticsearch_service", + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + + " deprecated and will be removed in a future release. Please specify a model_id field." + ); + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); + } else { + throw new IllegalArgumentException("Error parsing service settings, model_id must be provided"); + } + } else if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { platformArch.accept( modelListener.delegateFailureAndWrap( (delegate, arch) -> e5Case(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) ) ); + } else if (ElserModels.isValidModel(modelId)) { + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener); } @@ -239,7 +270,86 @@ static boolean modelVariantValidForArchitecture(Set platformArchitecture // platform agnostic model is always compatible return true; } + return modelId.equals( + selectDefaultModelVariantBasedOnClusterArchitecture( + platformArchitectures, + MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86, + MULTILINGUAL_E5_SMALL_MODEL_ID + ) + ); + } + private void elserCase( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + Map serviceSettingsMap, + ActionListener modelListener + ) { + var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + final String defaultModelId = selectDefaultModelVariantBasedOnClusterArchitecture( + platformArchitectures, + ELSER_V2_MODEL_LINUX_X86, + ELSER_V2_MODEL + ); + if (false == defaultModelId.equals(esServiceSettingsBuilder.getModelId())) { + + if (esServiceSettingsBuilder.getModelId() == null) { + // TODO remove this case once we remove the option to not pass model ID + esServiceSettingsBuilder.setModelId(defaultModelId); + } else if (esServiceSettingsBuilder.getModelId().equals(ELSER_V2_MODEL)) { + logger.warn( + "The platform agnostic model [{}] was requested on Linux x86_64. " + + "It is recommended to use the optimized model instead [{}]", + ELSER_V2_MODEL, + ELSER_V2_MODEL_LINUX_X86 + ); + } else { + throw new IllegalArgumentException( + "Error parsing request config, model id does not match any models available on this platform. Was [" + + esServiceSettingsBuilder.getModelId() + + "]. You may need to use a platform agnostic model." + ); + } + } + + DEPRECATION_LOGGER.warn( + DeprecationCategory.API, + "inference_api_elser_service", + "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" + + " [model_id] set to [{}] in the [service_settings]", + OLD_ELSER_SERVICE_NAME, + ElasticsearchInternalService.NAME, + defaultModelId + ); + + if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) { + throw new IllegalArgumentException( + "Error parsing request config, model id does not match any models available on this platform. Was [" + + esServiceSettingsBuilder.getModelId() + + "]" + ); + } + + throwIfNotEmptyMap(config, name()); + throwIfNotEmptyMap(serviceSettingsMap, name()); + + modelListener.onResponse( + new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(esServiceSettingsBuilder.build()), + ElserMlNodeTaskSettings.DEFAULT + ) + ); + } + + private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic( + Set platformArchitectures, + String modelId + ) { return modelId.equals( selectDefaultModelVariantBasedOnClusterArchitecture( platformArchitectures, @@ -276,6 +386,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M NAME, new MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)) ); + } else if (ElserModels.isValidModel(modelId)) { + return new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)), + ElserMlNodeTaskSettings.DEFAULT + ); } else { return createCustomElandModel( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 1acf19c5373b7..f8b5837ef387e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -83,7 +83,7 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap( validationException ); - // model id is optional as the ELSER and E5 service will default it + // model id is optional as the ELSER service will default it. TODO make this a required field once the elser service is removed String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (numAllocations == null && adaptiveAllocationsSettings == null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java similarity index 93% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java index bb668c314649d..827eb178f7633 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -13,7 +13,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; public class ElserInternalModel extends ElasticsearchInternalModel { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java similarity index 89% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java index fcbabd5a88fc6..f7bcd95c8bd28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java @@ -5,14 +5,13 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; import java.io.IOException; import java.util.Arrays; @@ -22,7 +21,7 @@ public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSe public static final String NAME = "elser_mlnode_service_settings"; - public static ElasticsearchInternalServiceSettings.Builder fromRequestMap(Map map) { + public static Builder fromRequestMap(Map map) { ValidationException validationException = new ValidationException(); var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java index 9b9f6e41113e5..934edaa96a15c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java similarity index 87% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java index af94d2813dd2c..37f528ea3a750 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import java.util.Set; @@ -23,7 +23,7 @@ public class ElserModels { ); public static boolean isValidModel(String model) { - return VALID_ELSER_MODEL_IDS.contains(model); + return model != null && VALID_ELSER_MODEL_IDS.contains(model); } public static boolean isValidEisModel(String model) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java deleted file mode 100644 index d36b8eca7661e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file has been contributed to by a Generative AI - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService; - -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Consumer; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; - -public class ElserInternalService extends BaseElasticsearchInternalService { - - public static final String NAME = "elser"; - - private static final String OLD_MODEL_ID_FIELD_NAME = "model_version"; - - public ElserInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { - super(context); - } - - // for testing - ElserInternalService( - InferenceServiceExtension.InferenceServiceFactoryContext context, - Consumer>> platformArch - ) { - super(context, platformArch); - } - - @Override - protected EnumSet supportedTaskTypes() { - return EnumSet.of(TaskType.SPARSE_EMBEDDING); - } - - @Override - public void parseRequestConfig( - String inferenceEntityId, - TaskType taskType, - Map config, - ActionListener parsedModelListener - ) { - try { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - var serviceSettingsBuilder = ElserInternalServiceSettings.fromRequestMap(serviceSettingsMap); - - Map taskSettingsMap; - // task settings are optional - if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); - - throwIfNotEmptyMap(config, NAME); - throwIfNotEmptyMap(serviceSettingsMap, NAME); - throwIfNotEmptyMap(taskSettingsMap, NAME); - - if (serviceSettingsBuilder.getModelId() == null) { - platformArch.accept(parsedModelListener.delegateFailureAndWrap((delegate, arch) -> { - serviceSettingsBuilder.setModelId( - selectDefaultModelVariantBasedOnClusterArchitecture(arch, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL) - ); - parsedModelListener.onResponse( - new ElserInternalModel( - inferenceEntityId, - taskType, - NAME, - new ElserInternalServiceSettings(serviceSettingsBuilder.build()), - taskSettings - ) - ); - })); - } else { - parsedModelListener.onResponse( - new ElserInternalModel( - inferenceEntityId, - taskType, - NAME, - new ElserInternalServiceSettings(serviceSettingsBuilder.build()), - taskSettings - ) - ); - } - } catch (Exception e) { - parsedModelListener.onFailure(e); - } - } - - @Override - public ElserInternalModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return parsePersistedConfig(inferenceEntityId, taskType, config); - } - - @Override - public ElserInternalModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - - // Change from old model_version field name to new model_id field name as of - // TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED - if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) { - String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class); - serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId); - } - - var serviceSettings = ElserInternalServiceSettings.fromPersistedMap(serviceSettingsMap); - - Map taskSettingsMap; - // task settings are optional - if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); - - return new ElserInternalModel(inferenceEntityId, taskType, NAME, new ElserInternalServiceSettings(serviceSettings), taskSettings); - } - - @Override - public void infer( - Model model, - @Nullable String query, - List inputs, - boolean stream, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener listener - ) { - // No task settings to override with requestTaskSettings - - try { - checkCompatibleTaskType(model.getConfigurations().getTaskType()); - } catch (Exception e) { - listener.onFailure(e); - return; - } - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - TextExpansionConfigUpdate.EMPTY_UPDATE, - inputs, - inputType, - timeout, - false // chunk - ); - - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) - ) - ); - } - - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - @Nullable ChunkingOptions chunkingOptions, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener); - } - - @Override - public void chunkedInfer( - Model model, - @Nullable String query, - List inputs, - Map taskSettings, - InputType inputType, - @Nullable ChunkingOptions chunkingOptions, - TimeValue timeout, - ActionListener> listener - ) { - try { - checkCompatibleTaskType(model.getConfigurations().getTaskType()); - } catch (Exception e) { - listener.onFailure(e); - return; - } - - var configUpdate = chunkingOptions != null - ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : new TokenizationConfigUpdate(null, null); - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - configUpdate, - inputs, - inputType, - timeout, - true // chunk - ); - - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(translateChunkedResults(inferenceResult.getInferenceResults())) - ) - ); - } - - private void checkCompatibleTaskType(TaskType taskType) { - if (TaskType.SPARSE_EMBEDDING.isAnyOrSame(taskType) == false) { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); - } - } - - private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map config) { - if (taskType != TaskType.SPARSE_EMBEDDING) { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); - } - - // no config options yet - return ElserMlNodeTaskSettings.DEFAULT; - } - - private List translateChunkedResults(List inferenceResults) { - var translated = new ArrayList(); - - for (var inferenceResult : inferenceResults) { - if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult) { - translated.add(InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult)); - } else if (inferenceResult instanceof ErrorInferenceResults error) { - translated.add(new ErrorChunkedInferenceResults(error.getException())); - } else { - throw new ElasticsearchStatusException( - "Expected a chunked inference [{}] received [{}]", - RestStatus.INTERNAL_SERVER_ERROR, - MlChunkedTextExpansionResults.NAME, - inferenceResult.getWriteableName() - ); - } - } - return translated; - } - - @Override - public String name() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_12_0; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index 5a1922fd200f5..03613901c7816 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -16,8 +16,8 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; public class ModelConfigurationsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index af13ce7944685..c9f4234331221 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java index a2b36cf9abdd5..1751e1c3be5e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -16,13 +16,13 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import java.io.IOException; import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index ab85e112418f5..d10c70c6f0f5e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -36,7 +36,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java index 41afef88d22c6..419db748d793d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index de9298f1b08dd..cd6da4c0ad8d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -9,10 +9,12 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -69,6 +71,8 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -97,9 +101,11 @@ public void shutdownThreadPool() { } public void testParseRequestConfig() { + // Null model variant var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) @@ -112,15 +118,16 @@ public void testParseRequestConfig() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } public void testParseRequestConfig_Misconfigured() { - // Null model variant + // Non-existent model variant { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) @@ -133,20 +140,21 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } // Invalid config map { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) ) ); - settings.put("not_a_valid_config_setting", randomAlphaOfLength(10)); + config.put("not_a_valid_config_setting", randomAlphaOfLength(10)); ActionListener modelListener = ActionListener.wrap( model -> fail("Model parsing should have failed"), @@ -154,7 +162,7 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } } @@ -182,7 +190,7 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } @@ -214,7 +222,7 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } @@ -247,6 +255,106 @@ public void testParseRequestConfig_E5() { } } + public void testParseRequestConfig_elser() { + // General happy case + { + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL + ) + ) + ); + + var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + config, + getElserModelVerificationActionListener( + elserServiceSettings, + null, + "The [elser] service is deprecated and will be removed in a future release. Use the [elasticsearch] service " + + "instead, with [model_id] set to [.elser_model_2] in the [service_settings]" + ) + ); + } + + // null model ID returns elser model for the provided platform (not linux) + { + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) + ) + ); + + var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + + String criticalWarning = + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + + " deprecated and will be removed in a future release. Please specify a model_id field."; + String warnWarning = + "The [elser] service is deprecated and will be removed in a future release. Use the [elasticsearch] service " + + "instead, with [model_id] set to [.elser_model_2] in the [service_settings]"; + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + config, + getElserModelVerificationActionListener(elserServiceSettings, criticalWarning, warnWarning) + ); + assertWarnings(true, new DeprecationWarning(DeprecationLogger.CRITICAL, criticalWarning)); + } + + // Invalid service settings + { + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL, + "not_a_valid_service_setting", + randomAlphaOfLength(10) + ) + ) + ); + + ActionListener modelListener = ActionListener.wrap( + model -> fail("Model parsing should have failed"), + e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) + ); + + service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, modelListener); + } + } + @SuppressWarnings("unchecked") public void testParseRequestConfig_Rerank() { // with task settings @@ -374,7 +482,7 @@ public void testParseRequestConfig_SparseEmbedding() { service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, settings, modelListener); } - private ActionListener getModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { + private ActionListener getE5ModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { return ActionListener.wrap(model -> { assertEquals( new MultilingualE5SmallModel( @@ -388,6 +496,30 @@ private ActionListener getModelVerificationActionListener(MultilingualE5S }, e -> { fail("Model parsing failed " + e.getMessage()); }); } + private ActionListener getElserModelVerificationActionListener( + ElserInternalServiceSettings elserServiceSettings, + String criticalWarning, + String warnWarning + ) { + return ActionListener.wrap(model -> { + assertWarnings( + true, + new DeprecationWarning(DeprecationLogger.CRITICAL, criticalWarning), + new DeprecationWarning(Level.WARN, warnWarning) + ); + assertEquals( + new ElserInternalModel( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + NAME, + elserServiceSettings, + ElserMlNodeTaskSettings.DEFAULT + ), + model + ); + }, e -> { fail("Model parsing failed " + e.getMessage()); }); + } + public void testParsePersistedConfig() { // Null model variant diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java similarity index 89% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java index ffbdf1a5a6178..f4e97b2c2e5e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java @@ -5,18 +5,16 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettingsTests; import java.io.IOException; import java.util.HashSet; -import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java similarity index 93% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java index d55065a5f9b27..a7de3fe8b8fdc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java new file mode 100644 index 0000000000000..fa0148ac69df5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.test.ESTestCase; + +public class ElserModelsTests extends ESTestCase { + + public static String randomElserModel() { + return randomFrom(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.VALID_ELSER_MODEL_IDS); + } + + public void testIsValidModel() { + assertTrue(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel(randomElserModel())); + } + + public void testIsValidEisModel() { + assertTrue( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL + ) + ); + } + + public void testIsInvalidModel() { + assertFalse(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel("invalid")); + } + + public void testIsInvalidEisModel() { + assertFalse( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java deleted file mode 100644 index 09abeb9b9b389..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ /dev/null @@ -1,548 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file was contributed to by a generative AI - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.InferencePlugin; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ElserInternalServiceTests extends ESTestCase { - - private static ThreadPool threadPool; - - @Before - public void setUpThreadPool() { - threadPool = createThreadPool(InferencePlugin.inferenceUtilityExecutor(Settings.EMPTY)); - } - - @After - public void shutdownThreadPool() { - TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); - } - - public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new ElserInternalModel( - inferenceEntityId, - taskType, - ElserInternalService.NAME, - ElserInternalServiceSettingsTests.createRandom(), - ElserMlNodeTaskSettingsTests.createRandom() - ); - default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); - }; - } - - public void testParseConfigStrict() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_id", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); - - } - - public void testParseConfigWithoutModelId() { - Client mockClient = mock(Client.class); - when(mockClient.threadPool()).thenReturn(threadPool); - var service = createService(mockClient); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); - - } - - public void testParseConfigLooseWithOldModelId() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_version", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); - - assertEquals(expectedModel, realModel); - - } - - private static ActionListener getModelVerificationListener(ElserInternalModel expectedModel) { - return ActionListener.wrap( - (model) -> { assertEquals(expectedModel, model); }, - (e) -> fail("Model verification should not fail " + e.getMessage()) - ); - } - - public void testParseConfigStrictWithNoTaskSettings() { - var service = createService(mock(Client.class), Set.of("Aarch64")); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelVerificationListener); - } - - public void testParseConfigStrictWithUnknownSettings() { - - var service = createService(mock(Client.class)); - - for (boolean throwOnUnknown : new boolean[] { true, false }) { - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - settings.put("foo", "bar"); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2", - "foo", - "bar" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, errorVerificationListener); - } - } - } - } - - public void testParseRequestConfig_DefaultModel() { - { - var service = createService(mock(Client.class), Set.of()); - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals(".elser_model_2", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail(e, "Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelActionListener); - } - { - var service = createService(mock(Client.class), Set.of("linux-x86_64")); - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals(".elser_model_2_linux-x86_64", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail(e, "Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, modelActionListener); - } - } - - @SuppressWarnings("unchecked") - public void testChunkInfer() { - var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(response); - return null; - }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); - - var model = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); - gotResults.set(true); - }, ESTestCase::fail); - - service.chunkedInfer( - model, - null, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) - ); - - if (gotResults.get() == false) { - terminate(threadpool); - } - assertTrue("Listener not called", gotResults.get()); - } - - @SuppressWarnings("unchecked") - public void testChunkInferSetsTokenization() { - var expectedSpan = new AtomicInteger(); - var expectedWindowSize = new AtomicReference(); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - try { - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); - var update = (TokenizationConfigUpdate) request.getUpdate(); - assertEquals(update.getSpanSettings().span(), expectedSpan.get()); - assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); - return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); - - var model = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - expectedSpan.set(-1); - expectedWindowSize.set(null); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - - expectedSpan.set(-1); - expectedWindowSize.set(256); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(256, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - } finally { - terminate(threadpool); - } - } - - @SuppressWarnings("unchecked") - public void testPutModel() { - var client = mock(Client.class); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); - - doAnswer(invocation -> { - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); - return null; - }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); - - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var model = new ElserInternalModel( - "my-elser", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - service.putModel(model, new ActionListener<>() { - @Override - public void onResponse(Boolean success) { - assertTrue(success); - } - - @Override - public void onFailure(Exception e) { - fail(e); - } - }); - - var putConfig = argument.getValue().getTrainedModelConfig(); - assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); - } - - private ElserInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); - return new ElserInternalService(context); - } - - private ElserInternalService createService(Client client, Set architectures) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool); - return new ElserInternalService(context, (l) -> l.onResponse(architectures)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java deleted file mode 100644 index f56e941dcc8c0..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.test.ESTestCase; - -public class ElserModelsTests extends ESTestCase { - - public static String randomElserModel() { - return randomFrom(ElserModels.VALID_ELSER_MODEL_IDS); - } - - public void testIsValidModel() { - assertTrue(ElserModels.isValidModel(randomElserModel())); - } - - public void testIsValidEisModel() { - assertTrue(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL)); - } - - public void testIsInvalidModel() { - assertFalse(ElserModels.isValidModel("invalid")); - } - - public void testIsInvalidEisModel() { - assertFalse(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86)); - } -}