From 99f648cb17b5bb06fa5e2f5a1a6b61337df489db Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 9 Sep 2024 16:02:39 -0400 Subject: [PATCH 01/22] merging --- .../ElasticsearchInternalService.java | 61 ++++++++++++++++--- .../ElasticsearchInternalServiceSettings.java | 4 +- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 93408c067098b..1d04f99e85419 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -54,6 +56,8 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; +import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class ElasticsearchInternalService extends BaseElasticsearchInternalService { @@ -65,6 +69,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); + public static final Set ELSER_VALID_IDS = Set.of(ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL); + + private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { super(context); @@ -95,6 +102,8 @@ public void parseRequestConfig( } if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { e5Case(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); + } else if (ELSER_VALID_IDS.contains(modelId)) { + elserCase(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener); } @@ -193,16 +202,54 @@ private void e5Case( ) { var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); - if (esServiceSettingsBuilder.getModelId() == null) { - esServiceSettingsBuilder.setModelId( - selectDefaultModelVariantBasedOnClusterArchitecture( - platformArchitectures, - MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86, - MULTILINGUAL_E5_SMALL_MODEL_ID - ) + if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) { + throw new IllegalArgumentException( + "Error parsing request config, model id does not match any models available on this platform. Was [" + + esServiceSettingsBuilder.getModelId() + + "]" ); } + throwIfNotEmptyMap(config, name()); + throwIfNotEmptyMap(serviceSettingsMap, name()); + + modelListener.onResponse( + new MultilingualE5SmallModel( + inferenceEntityId, + taskType, + NAME, + new MultilingualE5SmallInternalServiceSettings(esServiceSettingsBuilder.build()) + ) + ); + } + + private void elserCase( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + Map serviceSettingsMap, + ActionListener modelListener + ) { + var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + + if (false == esServiceSettingsBuilder.getModelId() + .equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL))) { + if (esServiceSettingsBuilder.getModelId().equals(ELSER_V2_MODEL)) { + logger.warn( + "The platform agnostic model [{}] was requested on Linux x86_64. It is recommended to use the optimized model instead [{}]", + ELSER_V2_MODEL, + ELSER_V2_MODEL_LINUX_X86 + ); + } else { + throw new IllegalArgumentException( + "Error parsing request config, model id does not match any models available on this platform. Was [" + + esServiceSettingsBuilder.getModelId() + + "]. You may need to use a platform agnostic model." + ); + } + } + if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) { throw new IllegalArgumentException( "Error parsing request config, model id does not match any models available on this platform. Was [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 1acf19c5373b7..ff1088b6473b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -27,6 +27,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; public class ElasticsearchInternalServiceSettings implements ServiceSettings { @@ -83,8 +84,7 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap( validationException ); - // model id is optional as the ELSER and E5 service will default it - String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (numAllocations == null && adaptiveAllocationsSettings == null) { validationException.addValidationError( From bc9c7d35fadea5c5fc37a538b71c16b1445d0dbc Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 14:09:46 -0400 Subject: [PATCH 02/22] copy elser service files into elasticsearch service --- .../integration/ModelRegistryIT.java | 6 +- .../ElasticsearchInternalService.java | 4 +- .../elasticsearch/ElserInternalModel.java | 67 +++ .../ElserInternalServiceSettings.java | 71 +++ .../services/elasticsearch/ElserModels.java | 33 ++ .../inference/ModelConfigurationsTests.java | 3 +- ...eSparseEmbeddingsServiceSettingsTests.java | 2 +- .../ElserInternalServiceSettingsTests.java | 93 +++ .../ElserInternalServiceTests.java | 548 ++++++++++++++++++ .../ElserMlNodeTaskSettingsTests.java | 34 ++ .../elasticsearch/ElserModelsTests.java | 38 ++ 11 files changed, 891 insertions(+), 8 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 5157683f2dce9..1713290a139c7 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -29,9 +29,9 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceTests; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; import org.junit.Before; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 1d04f99e85419..4901ed41f2958 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -56,6 +56,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.VALID_ELSER_MODEL_IDS; import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; @@ -69,7 +70,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); - public static final Set ELSER_VALID_IDS = Set.of(ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL); private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); @@ -102,7 +102,7 @@ public void parseRequestConfig( } if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { e5Case(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); - } else if (ELSER_VALID_IDS.contains(modelId)) { + } else if (VALID_ELSER_MODEL_IDS.contains(modelId)) { elserCase(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java new file mode 100644 index 0000000000000..0389411d433a8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; + +public class ElserInternalModel extends ElasticsearchInternalModel { + + public ElserInternalModel( + String inferenceEntityId, + TaskType taskType, + String service, + ElserInternalServiceSettings serviceSettings, + ElserMlNodeTaskSettings taskSettings + ) { + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + } + + @Override + public ElserInternalServiceSettings getServiceSettings() { + return (ElserInternalServiceSettings) super.getServiceSettings(); + } + + @Override + public ElserMlNodeTaskSettings getTaskSettings() { + return (ElserMlNodeTaskSettings) super.getTaskSettings(); + } + + @Override + public ActionListener getCreateTrainedModelAssignmentActionListener( + Model model, + ActionListener listener + ) { + return new ActionListener<>() { + @Override + public void onResponse(CreateTrainedModelAssignmentAction.Response response) { + listener.onResponse(Boolean.TRUE); + } + + @Override + public void onFailure(Exception e) { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + listener.onFailure( + new ResourceNotFoundException( + "Could not start the ELSER service as the ELSER model for this platform cannot be found." + + " ELSER needs to be downloaded before it can be started." + ) + ); + return; + } + listener.onFailure(e); + } + }; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java new file mode 100644 index 0000000000000..f5199599fa73e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Map; + +public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSettings { + + public static final String NAME = "elser_mlnode_service_settings"; + + public static Builder fromRequestMap(Map map) { + ValidationException validationException = new ValidationException(); + var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); + + String modelId = baseSettings.getModelId(); + if (modelId != null && ElserModels.isValidModel(modelId) == false) { + var ve = new ValidationException(); + ve.addValidationError( + "Unknown ELSER model ID [" + modelId + "]. Valid models are " + Arrays.toString(ElserModels.VALID_ELSER_MODEL_IDS.toArray()) + ); + throw ve; + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return baseSettings; + } + + public ElserInternalServiceSettings(ElasticsearchInternalServiceSettings other) { + super(other); + } + + public ElserInternalServiceSettings( + Integer numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + this(new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings)); + } + + public ElserInternalServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return ElserInternalServiceSettings.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_11_X; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java new file mode 100644 index 0000000000000..44ee0f394354e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import java.util.Set; + +public class ElserModels { + + public static final String ELSER_V1_MODEL = ".elser_model_1"; + // Default non platform specific v2 model + public static final String ELSER_V2_MODEL = ".elser_model_2"; + public static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64"; + + public static Set VALID_ELSER_MODEL_IDS = Set.of( + ElserModels.ELSER_V1_MODEL, + ElserModels.ELSER_V2_MODEL, + ElserModels.ELSER_V2_MODEL_LINUX_X86 + ); + + public static boolean isValidModel(String model) { + return VALID_ELSER_MODEL_IDS.contains(model); + } + + public static boolean isValidEisModel(String model) { + return ELSER_V2_MODEL.equals(model); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index 5a1922fd200f5..1a55178db28b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -14,10 +14,9 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; public class ModelConfigurationsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java index a2b36cf9abdd5..0d54944880fd5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -22,7 +22,7 @@ import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java new file mode 100644 index 0000000000000..abc1c848a3e0d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; + +import java.io.IOException; +import java.util.HashSet; + +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; + +public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase< + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings> { + + public static org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings createRandom() { + return new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + ElasticsearchInternalServiceSettingsTests.validInstance(randomElserModel()) + ); + } + + public void testBwcWrite() throws IOException { + { + var settings = new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null) + ); + var copy = copyInstance(settings, TransportVersions.V_8_12_0); + assertEquals(settings, copy); + } + { + var settings = new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null) + ); + var copy = copyInstance(settings, TransportVersions.V_8_11_X); + assertEquals(settings, copy); + } + } + + @Override + protected Writeable.Reader instanceReader() { + return org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings::new; + } + + @Override + protected org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings mutateInstance( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings instance + ) { + return switch (randomIntBetween(0, 2)) { + case 0 -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings( + instance.getNumAllocations() == null ? 1 : instance.getNumAllocations() + 1, + instance.getNumThreads(), + instance.modelId(), + null + ) + ); + case 1 -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings( + instance.getNumAllocations(), + instance.getNumThreads() + 1, + instance.modelId(), + null + ) + ); + case 2 -> { + var versions = new HashSet<>(ElserModels.VALID_ELSER_MODEL_IDS); + versions.remove(instance.modelId()); + yield new ElserInternalServiceSettings( + new ElasticsearchInternalServiceSettings( + instance.getNumAllocations(), + instance.getNumThreads(), + versions.iterator().next(), + null + ) + ); + } + default -> throw new IllegalStateException(); + }; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java new file mode 100644 index 0000000000000..72849d3539343 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java @@ -0,0 +1,548 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ElserInternalServiceTests extends ESTestCase { + + private static ThreadPool threadPool; + + @Before + public void setUpThreadPool() { + threadPool = new TestThreadPool("test"); + } + + @After + public void shutdownThreadPool() { + TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { + return switch (taskType) { + case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + inferenceEntityId, + taskType, + ElserInternalService.NAME, + ElserInternalServiceSettingsTests.createRandom(), + ElserMlNodeTaskSettingsTests.createRandom() + ); + default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); + }; + } + + public void testParseConfigStrict() { + var service = createService(mock(Client.class)); + + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + "model_id", + ".elser_model_1" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); + + var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElserInternalService.NAME, + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), + ElserMlNodeTaskSettings.DEFAULT + ); + + var modelVerificationListener = getModelVerificationListener(expectedModel); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); + + } + + public void testParseConfigLooseWithOldModelId() { + var service = createService(mock(Client.class)); + + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + "model_version", + ".elser_model_1" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); + + var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElserInternalService.NAME, + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), + ElserMlNodeTaskSettings.DEFAULT + ); + + var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); + + assertEquals(expectedModel, realModel); + + } + + private static ActionListener getModelVerificationListener( + org.elasticsearch.xpack.inference.services.elser.ElserInternalModel expectedModel + ) { + return ActionListener.wrap( + (model) -> { assertEquals(expectedModel, model); }, + (e) -> fail("Model verification should not fail " + e.getMessage()) + ); + } + + public void testParseConfigStrictWithNoTaskSettings() { + var service = createService(mock(Client.class)); + + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4 + ) + ) + ); + + var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElserInternalService.NAME, + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), + ElserMlNodeTaskSettings.DEFAULT + ); + + var modelVerificationListener = getModelVerificationListener(expectedModel); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); + + } + + public void testParseConfigStrictWithUnknownSettings() { + + var service = createService(mock(Client.class)); + + for (boolean throwOnUnknown : new boolean[] { true, false }) { + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, + ".elser_model_2" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); + settings.put("foo", "bar"); + + ActionListener errorVerificationListener = ActionListener.wrap((model) -> { + if (throwOnUnknown) { + fail("Model verification should fail when throwOnUnknown is true"); + } + }, (e) -> { + if (throwOnUnknown) { + assertThat( + e.getMessage(), + containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") + ); + } else { + fail("Model verification should not fail when throwOnUnknown is false"); + } + }); + + if (throwOnUnknown == false) { + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); + } else { + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + } + } + + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, + ".elser_model_2" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); + + ActionListener errorVerificationListener = ActionListener.wrap((model) -> { + if (throwOnUnknown) { + fail("Model verification should fail when throwOnUnknown is true"); + } + }, (e) -> { + if (throwOnUnknown) { + assertThat( + e.getMessage(), + containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") + ); + } else { + fail("Model verification should not fail when throwOnUnknown is false"); + } + }); + if (throwOnUnknown == false) { + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); + } else { + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + } + } + + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, + ".elser_model_2", + "foo", + "bar" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); + + ActionListener errorVerificationListener = ActionListener.wrap((model) -> { + if (throwOnUnknown) { + fail("Model verification should fail when throwOnUnknown is true"); + } + }, (e) -> { + if (throwOnUnknown) { + assertThat( + e.getMessage(), + containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") + ); + } else { + fail("Model verification should not fail when throwOnUnknown is false"); + } + }); + if (throwOnUnknown == false) { + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); + } else { + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + } + } + } + } + + public void testParseRequestConfig_DefaultModel() { + var service = createService(mock(Client.class)); + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4 + ) + ) + ); + + ActionListener modelActionListener = ActionListener.wrap((model) -> { + assertEquals( + ".elser_model_2", + ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() + ); + }, (e) -> { fail("Model verification should not fail"); }); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelActionListener); + } + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4 + ) + ) + ); + + ActionListener modelActionListener = ActionListener.wrap((model) -> { + assertEquals( + ".elser_model_2_linux-x86_64", + ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() + ); + }, (e) -> { fail("Model verification should not fail"); }); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of("linux-x86_64"), modelActionListener); + } + } + + @SuppressWarnings("unchecked") + public void testChunkInfer() { + var mlTrainedModelResults = new ArrayList(); + mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); + mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); + mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + + ThreadPool threadpool = new TestThreadPool("test"); + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + "elser", + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), + new ElserMlNodeTaskSettings() + ); + var service = createService(client); + + var gotResults = new AtomicBoolean(); + var resultsListener = ActionListener.>wrap(chunkedResponse -> { + assertThat(chunkedResponse, hasSize(3)); + assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); + assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); + var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); + assertThat(result3.getException(), instanceOf(RuntimeException.class)); + assertThat(result3.getException().getMessage(), containsString("boom")); + gotResults.set(true); + }, ESTestCase::fail); + + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) + ); + + if (gotResults.get() == false) { + terminate(threadpool); + } + assertTrue("Listener not called", gotResults.get()); + } + + @SuppressWarnings("unchecked") + public void testChunkInferSetsTokenization() { + var expectedSpan = new AtomicInteger(); + var expectedWindowSize = new AtomicReference(); + + ThreadPool threadpool = new TestThreadPool("test"); + Client client = mock(Client.class); + try { + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; + assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); + var update = (TokenizationConfigUpdate) request.getUpdate(); + assertEquals(update.getSpanSettings().span(), expectedSpan.get()); + assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); + return null; + }).when(client) + .execute( + same(InferTrainedModelDeploymentAction.INSTANCE), + any(InferTrainedModelDeploymentAction.Request.class), + any(ActionListener.class) + ); + + var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + "elser", + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), + new ElserMlNodeTaskSettings() + ); + var service = createService(client); + + expectedSpan.set(-1); + expectedWindowSize.set(null); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + + expectedSpan.set(-1); + expectedWindowSize.set(256); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(256, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + } finally { + terminate(threadpool); + } + } + + @SuppressWarnings("unchecked") + public void testPutModel() { + var client = mock(Client.class); + ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); + + doAnswer(invocation -> { + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); + return null; + }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); + + when(client.threadPool()).thenReturn(threadPool); + + var service = createService(client); + + var model = new ElserInternalModel( + "my-elser", + TaskType.SPARSE_EMBEDDING, + "elser", + new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), + ElserMlNodeTaskSettings.DEFAULT + ); + + service.putModel(model, new ActionListener<>() { + @Override + public void onResponse(Boolean success) { + assertTrue(success); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }); + + var putConfig = argument.getValue().getTrainedModelConfig(); + assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); + } + + private ElserInternalService createService(Client client) { + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); + return new ElserInternalService(context); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java new file mode 100644 index 0000000000000..792c35aaf8a89 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; + +public class ElserMlNodeTaskSettingsTests extends AbstractWireSerializingTestCase { + + public static ElserMlNodeTaskSettings createRandom() { + return ElserMlNodeTaskSettings.DEFAULT; // no options to randomise + } + + @Override + protected Writeable.Reader instanceReader() { + return ElserMlNodeTaskSettings::new; + } + + @Override + protected ElserMlNodeTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected ElserMlNodeTaskSettings mutateInstance(ElserMlNodeTaskSettings instance) { + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java new file mode 100644 index 0000000000000..1c630acc32cc3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; + +public class ElserModelsTests extends ESTestCase { + + public static String randomElserModel() { + return randomFrom(org.elasticsearch.xpack.inference.services.elser.ElserModels.VALID_ELSER_MODEL_IDS); + } + + public void testIsValidModel() { + assertTrue(org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidModel(randomElserModel())); + } + + public void testIsValidEisModel() { + assertTrue( + org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidEisModel( + org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL + ) + ); + } + + public void testIsInvalidModel() { + assertFalse(org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidModel("invalid")); + } + + public void testIsInvalidEisModel() { + assertFalse(org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86)); + } +} From 103e5f6aa61263b1ed06442b8ea97d780e6759b0 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 14:44:55 -0400 Subject: [PATCH 03/22] Add deprecation log message for elser service --- .../action/TransportPutInferenceModelAction.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index ec54294432fe8..f060ec0f06f46 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -20,6 +20,8 @@ import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.logging.DeprecationCategory; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -35,6 +37,8 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xcontent.XContentLocation; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; @@ -45,6 +49,8 @@ import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import java.io.IOException; import java.util.Map; @@ -57,6 +63,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction< PutInferenceModelAction.Response> { private static final Logger logger = LogManager.getLogger(TransportPutInferenceModelAction.class); + private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(PutInferenceModelAction.class); private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; @@ -113,6 +120,13 @@ protected void masterOperation( ) ); return; + } else if (serviceName.equals(ElserInternalService.NAME)) { + DEPRECATION_LOGGER.warn( + DeprecationCategory.API, + "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead.", + ElserInternalService.NAME, + ElasticsearchInternalService.NAME + ); } var service = serviceRegistry.getService(serviceName); From 6951cb27ab1b6749ffe860f86b7d527a3dd0cd79 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 15:57:54 -0400 Subject: [PATCH 04/22] improve deprecation warning --- .../integration/ModelRegistryIT.java | 4 +-- .../TransportPutInferenceModelAction.java | 31 +++++++++++++------ .../ElasticsearchInternalService.java | 3 +- .../ElasticsearchInternalServiceSettings.java | 4 +-- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 1713290a139c7..26fdc364e504f 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -27,11 +27,11 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.junit.Before; import java.io.IOException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index f060ec0f06f46..4fa6a994b4a62 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -37,8 +37,6 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xcontent.XContent; -import org.elasticsearch.xcontent.XContentLocation; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; @@ -57,6 +55,9 @@ import java.util.Set; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.selectDefaultModelVariantBasedOnClusterArchitecture; +import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; +import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class TransportPutInferenceModelAction extends TransportMasterNodeAction< PutInferenceModelAction.Request, @@ -120,13 +121,6 @@ protected void masterOperation( ) ); return; - } else if (serviceName.equals(ElserInternalService.NAME)) { - DEPRECATION_LOGGER.warn( - DeprecationCategory.API, - "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead.", - ElserInternalService.NAME, - ElasticsearchInternalService.NAME - ); } var service = serviceRegistry.getService(serviceName); @@ -174,6 +168,25 @@ protected void masterOperation( // Find the cluster platform as the service may need that // information when creating the model MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> { + + if (serviceName.equals(ElserInternalService.NAME)) { // TODO remove this block once the elser service is removed + String modelId = selectDefaultModelVariantBasedOnClusterArchitecture( + architectures, + ELSER_V2_MODEL_LINUX_X86, + ELSER_V2_MODEL + ); + + DEPRECATION_LOGGER.critical( + DeprecationCategory.API, + "inference_api_elser_service", + "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" + + " [model_id] set to [{}] in the [service_settings]", + ElserInternalService.NAME, + ElasticsearchInternalService.NAME, + modelId + ); + } + if (architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())) { parseAndStoreModel( service.get(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 4901ed41f2958..05c42af6d7106 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -237,7 +237,8 @@ private void elserCase( .equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL))) { if (esServiceSettingsBuilder.getModelId().equals(ELSER_V2_MODEL)) { logger.warn( - "The platform agnostic model [{}] was requested on Linux x86_64. It is recommended to use the optimized model instead [{}]", + "The platform agnostic model [{}] was requested on Linux x86_64. " + + "It is recommended to use the optimized model instead [{}]", ELSER_V2_MODEL, ELSER_V2_MODEL_LINUX_X86 ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index ff1088b6473b5..f8b5837ef387e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -27,7 +27,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; public class ElasticsearchInternalServiceSettings implements ServiceSettings { @@ -84,7 +83,8 @@ protected static ElasticsearchInternalServiceSettings.Builder fromMap( validationException ); - String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + // model id is optional as the ELSER service will default it. TODO make this a required field once the elser service is removed + String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (numAllocations == null && adaptiveAllocationsSettings == null) { validationException.addValidationError( From 99d9e84230b922b08a13d8c0d1914ec9fb56ef02 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 16:05:29 -0400 Subject: [PATCH 05/22] change elasticsearch internal service elser case to use elser model --- .../elasticsearch/ElasticsearchInternalService.java | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 05c42af6d7106..ad3550dd8236b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -263,12 +263,7 @@ private void elserCase( throwIfNotEmptyMap(serviceSettingsMap, name()); modelListener.onResponse( - new MultilingualE5SmallModel( - inferenceEntityId, - taskType, - NAME, - new MultilingualE5SmallInternalServiceSettings(esServiceSettingsBuilder.build()) - ) + new ElserInternalModel(inferenceEntityId, taskType, NAME, new ElserInternalServiceSettings(esServiceSettingsBuilder.build())) ); } From f9da1bd7cdbbfa63f130b79fbb185dc66e683dde Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 16:11:11 -0400 Subject: [PATCH 06/22] switch elasticsearch elser tests to use elasticsearch elser --- .../services/elser/ElserInternalModel.java | 15 +- .../ElserInternalServiceSettingsTests.java | 33 +- .../ElserInternalServiceTests.java | 548 ------------------ .../elasticsearch/ElserModelsTests.java | 2 +- .../elser/ElserMlNodeTaskSettingsTests.java | 33 -- 5 files changed, 17 insertions(+), 614 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java index bb668c314649d..6efdfe7685b9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java @@ -10,6 +10,7 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -17,14 +18,8 @@ public class ElserInternalModel extends ElasticsearchInternalModel { - public ElserInternalModel( - String inferenceEntityId, - TaskType taskType, - String service, - ElserInternalServiceSettings serviceSettings, - ElserMlNodeTaskSettings taskSettings - ) { - super(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + public ElserInternalModel(String inferenceEntityId, TaskType taskType, String service, ElserInternalServiceSettings serviceSettings) { + super(inferenceEntityId, taskType, service, serviceSettings); } @Override @@ -33,8 +28,8 @@ public ElserInternalServiceSettings getServiceSettings() { } @Override - public ElserMlNodeTaskSettings getTaskSettings() { - return (ElserMlNodeTaskSettings) super.getTaskSettings(); + public TaskSettings getTaskSettings() { + return super.getTaskSettings(); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java index abc1c848a3e0d..f4e97b2c2e5e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettingsTests.java @@ -10,56 +10,45 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; import java.io.IOException; import java.util.HashSet; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; -public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase< - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings> { +public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase { - public static org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings createRandom() { - return new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( - ElasticsearchInternalServiceSettingsTests.validInstance(randomElserModel()) - ); + public static ElserInternalServiceSettings createRandom() { + return new ElserInternalServiceSettings(ElasticsearchInternalServiceSettingsTests.validInstance(randomElserModel())); } public void testBwcWrite() throws IOException { { - var settings = new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( - new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null) - ); + var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null)); var copy = copyInstance(settings, TransportVersions.V_8_12_0); assertEquals(settings, copy); } { - var settings = new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( - new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null) - ); + var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null)); var copy = copyInstance(settings, TransportVersions.V_8_11_X); assertEquals(settings, copy); } } @Override - protected Writeable.Reader instanceReader() { - return org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings::new; + protected Writeable.Reader instanceReader() { + return ElserInternalServiceSettings::new; } @Override - protected org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings createTestInstance() { + protected ElserInternalServiceSettings createTestInstance() { return createRandom(); } @Override - protected org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings mutateInstance( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings instance - ) { + protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettings instance) { return switch (randomIntBetween(0, 2)) { - case 0 -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + case 0 -> new ElserInternalServiceSettings( new ElasticsearchInternalServiceSettings( instance.getNumAllocations() == null ? 1 : instance.getNumAllocations() + 1, instance.getNumThreads(), @@ -67,7 +56,7 @@ protected org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceS null ) ); - case 1 -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings( + case 1 -> new ElserInternalServiceSettings( new ElasticsearchInternalServiceSettings( instance.getNumAllocations(), instance.getNumThreads() + 1, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java index 72849d3539343..e69de29bb2d1d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java @@ -1,548 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elasticsearch; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ElserInternalServiceTests extends ESTestCase { - - private static ThreadPool threadPool; - - @Before - public void setUpThreadPool() { - threadPool = new TestThreadPool("test"); - } - - @After - public void shutdownThreadPool() { - TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); - } - - public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - inferenceEntityId, - taskType, - ElserInternalService.NAME, - ElserInternalServiceSettingsTests.createRandom(), - ElserMlNodeTaskSettingsTests.createRandom() - ); - default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); - }; - } - - public void testParseConfigStrict() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_id", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - - } - - public void testParseConfigLooseWithOldModelId() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_version", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); - - assertEquals(expectedModel, realModel); - - } - - private static ActionListener getModelVerificationListener( - org.elasticsearch.xpack.inference.services.elser.ElserInternalModel expectedModel - ) { - return ActionListener.wrap( - (model) -> { assertEquals(expectedModel, model); }, - (e) -> fail("Model verification should not fail " + e.getMessage()) - ); - } - - public void testParseConfigStrictWithNoTaskSettings() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4 - ) - ) - ); - - var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - - } - - public void testParseConfigStrictWithUnknownSettings() { - - var service = createService(mock(Client.class)); - - for (boolean throwOnUnknown : new boolean[] { true, false }) { - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - settings.put("foo", "bar"); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2", - "foo", - "bar" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - } - } - - public void testParseRequestConfig_DefaultModel() { - var service = createService(mock(Client.class)); - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4 - ) - ) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals( - ".elser_model_2", - ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() - ); - }, (e) -> { fail("Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelActionListener); - } - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4 - ) - ) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals( - ".elser_model_2_linux-x86_64", - ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() - ); - }, (e) -> { fail("Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of("linux-x86_64"), modelActionListener); - } - } - - @SuppressWarnings("unchecked") - public void testChunkInfer() { - var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(response); - return null; - }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); - - var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); - gotResults.set(true); - }, ESTestCase::fail); - - service.chunkedInfer( - model, - null, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) - ); - - if (gotResults.get() == false) { - terminate(threadpool); - } - assertTrue("Listener not called", gotResults.get()); - } - - @SuppressWarnings("unchecked") - public void testChunkInferSetsTokenization() { - var expectedSpan = new AtomicInteger(); - var expectedWindowSize = new AtomicReference(); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - try { - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); - var update = (TokenizationConfigUpdate) request.getUpdate(); - assertEquals(update.getSpanSettings().span(), expectedSpan.get()); - assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); - return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); - - var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - expectedSpan.set(-1); - expectedWindowSize.set(null); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - - expectedSpan.set(-1); - expectedWindowSize.set(256); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(256, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - } finally { - terminate(threadpool); - } - } - - @SuppressWarnings("unchecked") - public void testPutModel() { - var client = mock(Client.class); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); - - doAnswer(invocation -> { - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); - return null; - }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); - - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var model = new ElserInternalModel( - "my-elser", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - service.putModel(model, new ActionListener<>() { - @Override - public void onResponse(Boolean success) { - assertTrue(success); - } - - @Override - public void onFailure(Exception e) { - fail(e); - } - }); - - var putConfig = argument.getValue().getTrainedModelConfig(); - assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); - } - - private ElserInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); - return new ElserInternalService(context); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java index 1c630acc32cc3..a6c16a92d386c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; public class ElserModelsTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java deleted file mode 100644 index d55065a5f9b27..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettingsTests.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -public class ElserMlNodeTaskSettingsTests extends AbstractWireSerializingTestCase { - - public static ElserMlNodeTaskSettings createRandom() { - return ElserMlNodeTaskSettings.DEFAULT; // no options to randomise - } - - @Override - protected Writeable.Reader instanceReader() { - return ElserMlNodeTaskSettings::new; - } - - @Override - protected ElserMlNodeTaskSettings createTestInstance() { - return createRandom(); - } - - @Override - protected ElserMlNodeTaskSettings mutateInstance(ElserMlNodeTaskSettings instance) { - return null; - } -} From 2362d7e040330d436e141c5ad0e688479865a604 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:26:49 -0400 Subject: [PATCH 07/22] Update docs/changelog/113216.yaml --- docs/changelog/113216.yaml | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 docs/changelog/113216.yaml diff --git a/docs/changelog/113216.yaml b/docs/changelog/113216.yaml new file mode 100644 index 0000000000000..b559ba8bb8825 --- /dev/null +++ b/docs/changelog/113216.yaml @@ -0,0 +1,10 @@ +pr: 113216 +summary: "[Inference API] Deprecate elser service" +area: Machine Learning +type: deprecation +issues: [] +deprecation: + title: "[Inference API] Deprecate elser service" + area: Machine Learning + details: The `elser` service of the inference API will be removed in an upcoming release. Please use the elasticsearch service instead. + impact: In the current version there is no impact. In a future version, users of the `elser` service will no longer be able to use it, and will be required to use the `elasticsearch` service to access elser through the inference API. From 118519c250e25bbf49e6c3cf4524f3425991aa54 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Fri, 20 Sep 2024 10:45:01 -0400 Subject: [PATCH 08/22] alias elser service to elasticsearch --- .../elasticsearch/inference/InferenceServiceRegistry.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index 40b4e37f36509..20659098e4bdd 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -46,8 +46,14 @@ public Map getServices() { } public Optional getService(String serviceName) { + + if ("elser".equals(serviceName)) { // ElserService.NAME before removal + // here we are aliasing the elser service to use the elasticsearch service instead + return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME + } else { return Optional.ofNullable(services.get(serviceName)); } + } public List getNamedWriteables() { return namedWriteables; From 1e4ec6af34e4a4f70b46d819f5d43733a9d66340 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Fri, 20 Sep 2024 16:32:21 -0400 Subject: [PATCH 09/22] delete elser service package now that elasticsearch service supports it and has aliased it --- .../inference/InferenceServiceRegistry.java | 4 +- .../integration/ModelRegistryIT.java | 25 +- .../InferenceNamedWriteablesProvider.java | 4 +- .../xpack/inference/InferencePlugin.java | 2 - .../TransportPutInferenceModelAction.java | 13 +- ...InferenceServiceSparseEmbeddingsModel.java | 2 +- ...erviceSparseEmbeddingsServiceSettings.java | 2 +- .../BaseElasticsearchInternalService.java | 1 - .../ElasticsearchInternalService.java | 12 +- .../elasticsearch/ElserInternalModel.java | 2 - .../ElserInternalServiceSettings.java | 1 - .../ElserMlNodeTaskSettings.java | 2 +- .../services/elser/ElserInternalModel.java | 61 --- .../services/elser/ElserInternalService.java | 281 ---------- .../elser/ElserInternalServiceSettings.java | 71 --- .../inference/services/elser/ElserModels.java | 33 -- .../inference/ModelConfigurationsTests.java | 1 + ...enceServiceSparseEmbeddingsModelTests.java | 2 +- ...eSparseEmbeddingsServiceSettingsTests.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 3 +- ...ticsearchInternalServiceSettingsTests.java | 1 - .../ElserMlNodeTaskSettingsTests.java | 1 - .../elasticsearch/ElserModelsTests.java | 15 +- .../ElserInternalServiceSettingsTests.java | 84 --- .../elser/ElserInternalServiceTests.java | 516 ------------------ .../services/elser/ElserModelsTests.java | 33 -- 26 files changed, 53 insertions(+), 1121 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/{elser => elasticsearch}/ElserMlNodeTaskSettings.java (96%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index 20659098e4bdd..f1ce94173a550 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -51,8 +51,8 @@ public Optional getService(String serviceName) { // here we are aliasing the elser service to use the elasticsearch service instead return Optional.ofNullable(services.get("elasticsearch")); // ElasticsearchInternalService.NAME } else { - return Optional.ofNullable(services.get(serviceName)); - } + return Optional.ofNullable(services.get(serviceName)); + } } public List getNamedWriteables() { diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 26fdc364e504f..012ef10962c7c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -27,11 +27,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.junit.Before; import java.io.IOException; @@ -117,8 +116,10 @@ public void testGetModel() throws Exception { assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElserInternalService(new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class))); - ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets( + var elserService = new ElasticsearchInternalService( + new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class)) + ); + ElserInternalModel roundTripModel = (ElserInternalModel) elserService.parsePersistedConfigWithSecrets( modelHolder.get().inferenceEntityId(), modelHolder.get().taskType(), modelHolder.get().settings(), @@ -274,7 +275,17 @@ public void testGetModelWithSecrets() throws InterruptedException { } private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) { - return ElserInternalServiceTests.randomModelConfig(inferenceEntityId, taskType); + return switch (taskType) { + case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel( + inferenceEntityId, + taskType, + ElasticsearchInternalService.NAME, + ElserInternalServiceSettingsTests.createRandom(), + ElserMlNodeTaskSettingsTests.createRandom() + ); + default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); + }; + } protected void blockingCall(Consumer> function, AtomicReference response, AtomicReference error) @@ -297,7 +308,7 @@ private static Model buildModelWithUnknownField(String inferenceEntityId) { new ModelWithUnknownField( inferenceEntityId, TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, + ElasticsearchInternalService.NAME, ElserInternalServiceSettingsTests.createRandom(), ElserMlNodeTaskSettingsTests.createRandom() ) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 336626cd1db20..02bddb6076d69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -64,9 +64,9 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 16bd0942c6c26..954661a04fade 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -86,7 +86,6 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; @@ -229,7 +228,6 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( - ElserInternalService::new, context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 4fa6a994b4a62..43c4a724f9571 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -48,7 +48,6 @@ import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; import java.io.IOException; import java.util.Map; @@ -56,8 +55,8 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.selectDefaultModelVariantBasedOnClusterArchitecture; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class TransportPutInferenceModelAction extends TransportMasterNodeAction< PutInferenceModelAction.Request, @@ -168,20 +167,20 @@ protected void masterOperation( // Find the cluster platform as the service may need that // information when creating the model MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> { - - if (serviceName.equals(ElserInternalService.NAME)) { // TODO remove this block once the elser service is removed + String ELSER_SERVICE_NAME = "elser"; + if (serviceName.equals(ELSER_SERVICE_NAME)) { String modelId = selectDefaultModelVariantBasedOnClusterArchitecture( architectures, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL ); - DEPRECATION_LOGGER.critical( + DEPRECATION_LOGGER.warn( DeprecationCategory.API, "inference_api_elser_service", "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" + " [model_id] set to [{}] in the [service_settings]", - ElserInternalService.NAME, + ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME, modelId ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java index 163e3dd654150..bbbae736dbeb9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java @@ -18,7 +18,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import java.net.URI; import java.net.URISyntaxException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java index 15b89525f7915..bbda1bb716794 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java @@ -17,7 +17,7 @@ import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 457416370e559..93d40c24f9da9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import java.io.IOException; import java.util.EnumSet; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ad3550dd8236b..b23571c157137 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -56,9 +56,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.VALID_ELSER_MODEL_IDS; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class ElasticsearchInternalService extends BaseElasticsearchInternalService { @@ -263,7 +263,13 @@ private void elserCase( throwIfNotEmptyMap(serviceSettingsMap, name()); modelListener.onResponse( - new ElserInternalModel(inferenceEntityId, taskType, NAME, new ElserInternalServiceSettings(esServiceSettingsBuilder.build())) + new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(esServiceSettingsBuilder.build()), + ElserMlNodeTaskSettings.DEFAULT + ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java index 0389411d433a8..827eb178f7633 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java @@ -13,8 +13,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; public class ElserInternalModel extends ElasticsearchInternalModel { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java index f5199599fa73e..f7bcd95c8bd28 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java @@ -12,7 +12,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; import java.io.IOException; import java.util.Arrays; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java index 9b9f6e41113e5..934edaa96a15c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettings.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.elser; +package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java deleted file mode 100644 index 6efdfe7685b9f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalModel.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; - -public class ElserInternalModel extends ElasticsearchInternalModel { - - public ElserInternalModel(String inferenceEntityId, TaskType taskType, String service, ElserInternalServiceSettings serviceSettings) { - super(inferenceEntityId, taskType, service, serviceSettings); - } - - @Override - public ElserInternalServiceSettings getServiceSettings() { - return (ElserInternalServiceSettings) super.getServiceSettings(); - } - - @Override - public TaskSettings getTaskSettings() { - return super.getTaskSettings(); - } - - @Override - public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, - ActionListener listener - ) { - return new ActionListener<>() { - @Override - public void onResponse(CreateTrainedModelAssignmentAction.Response response) { - listener.onResponse(Boolean.TRUE); - } - - @Override - public void onFailure(Exception e) { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { - listener.onFailure( - new ResourceNotFoundException( - "Could not start the ELSER service as the ELSER model for this platform cannot be found." - + " ELSER needs to be downloaded before it can be started." - ) - ); - return; - } - listener.onFailure(e); - } - }; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java deleted file mode 100644 index 746cb6e89fad0..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ /dev/null @@ -1,281 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file has been contributed to by a Generative AI - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.services.ServiceUtils; -import org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService; - -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL; -import static org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL_LINUX_X86; - -public class ElserInternalService extends BaseElasticsearchInternalService { - - public static final String NAME = "elser"; - - private static final String OLD_MODEL_ID_FIELD_NAME = "model_version"; - - public ElserInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { - super(context); - } - - @Override - protected EnumSet supportedTaskTypes() { - return EnumSet.of(TaskType.SPARSE_EMBEDDING); - } - - @Override - public void parseRequestConfig( - String inferenceEntityId, - TaskType taskType, - Map config, - Set modelArchitectures, - ActionListener parsedModelListener - ) { - try { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - var serviceSettingsBuilder = ElserInternalServiceSettings.fromRequestMap(serviceSettingsMap); - - if (serviceSettingsBuilder.getModelId() == null) { - serviceSettingsBuilder.setModelId( - selectDefaultModelVariantBasedOnClusterArchitecture(modelArchitectures, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL) - ); - } - - Map taskSettingsMap; - // task settings are optional - if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); - - throwIfNotEmptyMap(config, NAME); - throwIfNotEmptyMap(serviceSettingsMap, NAME); - throwIfNotEmptyMap(taskSettingsMap, NAME); - - parsedModelListener.onResponse( - new ElserInternalModel( - inferenceEntityId, - taskType, - NAME, - new ElserInternalServiceSettings(serviceSettingsBuilder.build()), - taskSettings - ) - ); - } catch (Exception e) { - parsedModelListener.onFailure(e); - } - } - - @Override - public ElserInternalModel parsePersistedConfigWithSecrets( - String inferenceEntityId, - TaskType taskType, - Map config, - Map secrets - ) { - return parsePersistedConfig(inferenceEntityId, taskType, config); - } - - @Override - public ElserInternalModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { - Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); - - // Change from old model_version field name to new model_id field name as of - // TransportVersions.ML_TEXT_EMBEDDING_INFERENCE_SERVICE_ADDED - if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) { - String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class); - serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId); - } - - var serviceSettings = ElserInternalServiceSettings.fromPersistedMap(serviceSettingsMap); - - Map taskSettingsMap; - // task settings are optional - if (config.containsKey(ModelConfigurations.TASK_SETTINGS)) { - taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); - } else { - taskSettingsMap = Map.of(); - } - - var taskSettings = taskSettingsFromMap(taskType, taskSettingsMap); - - return new ElserInternalModel(inferenceEntityId, taskType, NAME, new ElserInternalServiceSettings(serviceSettings), taskSettings); - } - - @Override - public void infer( - Model model, - @Nullable String query, - List inputs, - boolean stream, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener listener - ) { - // No task settings to override with requestTaskSettings - - try { - checkCompatibleTaskType(model.getConfigurations().getTaskType()); - } catch (Exception e) { - listener.onFailure(e); - return; - } - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - TextExpansionConfigUpdate.EMPTY_UPDATE, - inputs, - inputType, - timeout, - false // chunk - ); - - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())) - ) - ); - } - - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - @Nullable ChunkingOptions chunkingOptions, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener); - } - - @Override - public void chunkedInfer( - Model model, - @Nullable String query, - List inputs, - Map taskSettings, - InputType inputType, - @Nullable ChunkingOptions chunkingOptions, - TimeValue timeout, - ActionListener> listener - ) { - try { - checkCompatibleTaskType(model.getConfigurations().getTaskType()); - } catch (Exception e) { - listener.onFailure(e); - return; - } - - var configUpdate = chunkingOptions != null - ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : new TokenizationConfigUpdate(null, null); - - var request = buildInferenceRequest( - model.getConfigurations().getInferenceEntityId(), - configUpdate, - inputs, - inputType, - timeout, - true // chunk - ); - - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse(translateChunkedResults(inferenceResult.getInferenceResults())) - ) - ); - } - - private void checkCompatibleTaskType(TaskType taskType) { - if (TaskType.SPARSE_EMBEDDING.isAnyOrSame(taskType) == false) { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); - } - } - - private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map config) { - if (taskType != TaskType.SPARSE_EMBEDDING) { - throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); - } - - // no config options yet - return ElserMlNodeTaskSettings.DEFAULT; - } - - private List translateChunkedResults(List inferenceResults) { - var translated = new ArrayList(); - - for (var inferenceResult : inferenceResults) { - if (inferenceResult instanceof MlChunkedTextExpansionResults mlChunkedResult) { - translated.add(InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult)); - } else if (inferenceResult instanceof ErrorInferenceResults error) { - translated.add(new ErrorChunkedInferenceResults(error.getException())); - } else { - throw new ElasticsearchStatusException( - "Expected a chunked inference [{}] received [{}]", - RestStatus.INTERNAL_SERVER_ERROR, - MlChunkedTextExpansionResults.NAME, - inferenceResult.getWriteableName() - ); - } - } - return translated; - } - - @Override - public String name() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_12_0; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java deleted file mode 100644 index fcbabd5a88fc6..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Map; - -public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSettings { - - public static final String NAME = "elser_mlnode_service_settings"; - - public static ElasticsearchInternalServiceSettings.Builder fromRequestMap(Map map) { - ValidationException validationException = new ValidationException(); - var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException); - - String modelId = baseSettings.getModelId(); - if (modelId != null && ElserModels.isValidModel(modelId) == false) { - var ve = new ValidationException(); - ve.addValidationError( - "Unknown ELSER model ID [" + modelId + "]. Valid models are " + Arrays.toString(ElserModels.VALID_ELSER_MODEL_IDS.toArray()) - ); - throw ve; - } - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return baseSettings; - } - - public ElserInternalServiceSettings(ElasticsearchInternalServiceSettings other) { - super(other); - } - - public ElserInternalServiceSettings( - Integer numAllocations, - int numThreads, - String modelId, - AdaptiveAllocationsSettings adaptiveAllocationsSettings - ) { - this(new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId, adaptiveAllocationsSettings)); - } - - public ElserInternalServiceSettings(StreamInput in) throws IOException { - super(in); - } - - @Override - public String getWriteableName() { - return ElserInternalServiceSettings.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_11_X; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java deleted file mode 100644 index af94d2813dd2c..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserModels.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import java.util.Set; - -public class ElserModels { - - public static final String ELSER_V1_MODEL = ".elser_model_1"; - // Default non platform specific v2 model - public static final String ELSER_V2_MODEL = ".elser_model_2"; - public static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64"; - - public static Set VALID_ELSER_MODEL_IDS = Set.of( - ElserModels.ELSER_V1_MODEL, - ElserModels.ELSER_V2_MODEL, - ElserModels.ELSER_V2_MODEL_LINUX_X86 - ); - - public static boolean isValidModel(String model) { - return VALID_ELSER_MODEL_IDS.contains(model); - } - - public static boolean isValidEisModel(String model) { - return ELSER_V2_MODEL.equals(model); - } - -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index 1a55178db28b4..df3712ea344a2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index af13ce7944685..c9f4234331221 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java index 0d54944880fd5..1751e1c3be5e8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 0bbf2be7301d8..758cd61b1b8ce 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -36,7 +36,8 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; +import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java index 41afef88d22c6..419db748d793d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettingsTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; import java.io.IOException; import java.util.HashMap; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java index 792c35aaf8a89..a7de3fe8b8fdc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserMlNodeTaskSettingsTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; public class ElserMlNodeTaskSettingsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java index a6c16a92d386c..fa0148ac69df5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModelsTests.java @@ -8,31 +8,32 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; public class ElserModelsTests extends ESTestCase { public static String randomElserModel() { - return randomFrom(org.elasticsearch.xpack.inference.services.elser.ElserModels.VALID_ELSER_MODEL_IDS); + return randomFrom(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.VALID_ELSER_MODEL_IDS); } public void testIsValidModel() { - assertTrue(org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidModel(randomElserModel())); + assertTrue(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel(randomElserModel())); } public void testIsValidEisModel() { assertTrue( - org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidEisModel( - org.elasticsearch.xpack.inference.services.elser.ElserModels.ELSER_V2_MODEL + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL ) ); } public void testIsInvalidModel() { - assertFalse(org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidModel("invalid")); + assertFalse(org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidModel("invalid")); } public void testIsInvalidEisModel() { - assertFalse(org.elasticsearch.xpack.inference.services.elser.ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86)); + assertFalse( + org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java deleted file mode 100644 index ffbdf1a5a6178..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettingsTests; - -import java.io.IOException; -import java.util.HashSet; - -import static org.elasticsearch.xpack.inference.services.elser.ElserModelsTests.randomElserModel; - -public class ElserInternalServiceSettingsTests extends AbstractWireSerializingTestCase { - - public static ElserInternalServiceSettings createRandom() { - return new ElserInternalServiceSettings(ElasticsearchInternalServiceSettingsTests.validInstance(randomElserModel())); - } - - public void testBwcWrite() throws IOException { - { - var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null)); - var copy = copyInstance(settings, TransportVersions.V_8_12_0); - assertEquals(settings, copy); - } - { - var settings = new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(1, 1, ".elser_model_1", null)); - var copy = copyInstance(settings, TransportVersions.V_8_11_X); - assertEquals(settings, copy); - } - } - - @Override - protected Writeable.Reader instanceReader() { - return ElserInternalServiceSettings::new; - } - - @Override - protected ElserInternalServiceSettings createTestInstance() { - return createRandom(); - } - - @Override - protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettings instance) { - return switch (randomIntBetween(0, 2)) { - case 0 -> new ElserInternalServiceSettings( - new ElasticsearchInternalServiceSettings( - instance.getNumAllocations() == null ? 1 : instance.getNumAllocations() + 1, - instance.getNumThreads(), - instance.modelId(), - null - ) - ); - case 1 -> new ElserInternalServiceSettings( - new ElasticsearchInternalServiceSettings( - instance.getNumAllocations(), - instance.getNumThreads() + 1, - instance.modelId(), - null - ) - ); - case 2 -> { - var versions = new HashSet<>(ElserModels.VALID_ELSER_MODEL_IDS); - versions.remove(instance.modelId()); - yield new ElserInternalServiceSettings( - new ElasticsearchInternalServiceSettings( - instance.getNumAllocations(), - instance.getNumThreads(), - versions.iterator().next(), - null - ) - ); - } - default -> throw new IllegalStateException(); - }; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java deleted file mode 100644 index 85add1a0090c8..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ /dev/null @@ -1,516 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - * - * this file was contributed to by a generative AI - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ElserInternalServiceTests extends ESTestCase { - - private static ThreadPool threadPool; - - @Before - public void setUpThreadPool() { - threadPool = new TestThreadPool("test"); - } - - @After - public void shutdownThreadPool() { - TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); - } - - public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new ElserInternalModel( - inferenceEntityId, - taskType, - ElserInternalService.NAME, - ElserInternalServiceSettingsTests.createRandom(), - ElserMlNodeTaskSettingsTests.createRandom() - ); - default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); - }; - } - - public void testParseConfigStrict() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_id", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - - } - - public void testParseConfigLooseWithOldModelId() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_version", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); - - assertEquals(expectedModel, realModel); - - } - - private static ActionListener getModelVerificationListener(ElserInternalModel expectedModel) { - return ActionListener.wrap( - (model) -> { assertEquals(expectedModel, model); }, - (e) -> fail("Model verification should not fail " + e.getMessage()) - ); - } - - public void testParseConfigStrictWithNoTaskSettings() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - var expectedModel = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - - } - - public void testParseConfigStrictWithUnknownSettings() { - - var service = createService(mock(Client.class)); - - for (boolean throwOnUnknown : new boolean[] { true, false }) { - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - settings.put("foo", "bar"); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - ElserInternalServiceSettings.NUM_THREADS, - 4, - ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2", - "foo", - "bar" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - } - } - - public void testParseRequestConfig_DefaultModel() { - var service = createService(mock(Client.class)); - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals(".elser_model_2", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail("Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelActionListener); - } - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>(Map.of(ElserInternalServiceSettings.NUM_ALLOCATIONS, 1, ElserInternalServiceSettings.NUM_THREADS, 4)) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals(".elser_model_2_linux-x86_64", ((ElserInternalModel) model).getServiceSettings().modelId()); - }, (e) -> { fail("Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of("linux-x86_64"), modelActionListener); - } - } - - @SuppressWarnings("unchecked") - public void testChunkInfer() { - var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(response); - return null; - }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); - - var model = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); - gotResults.set(true); - }, ESTestCase::fail); - - service.chunkedInfer( - model, - null, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) - ); - - if (gotResults.get() == false) { - terminate(threadpool); - } - assertTrue("Listener not called", gotResults.get()); - } - - @SuppressWarnings("unchecked") - public void testChunkInferSetsTokenization() { - var expectedSpan = new AtomicInteger(); - var expectedWindowSize = new AtomicReference(); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - try { - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); - var update = (TokenizationConfigUpdate) request.getUpdate(); - assertEquals(update.getSpanSettings().span(), expectedSpan.get()); - assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); - return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); - - var model = new ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - expectedSpan.set(-1); - expectedWindowSize.set(null); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - - expectedSpan.set(-1); - expectedWindowSize.set(256); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(256, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - } finally { - terminate(threadpool); - } - } - - @SuppressWarnings("unchecked") - public void testPutModel() { - var client = mock(Client.class); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); - - doAnswer(invocation -> { - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); - return null; - }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); - - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var model = new ElserInternalModel( - "my-elser", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - service.putModel(model, new ActionListener<>() { - @Override - public void onResponse(Boolean success) { - assertTrue(success); - } - - @Override - public void onFailure(Exception e) { - fail(e); - } - }); - - var putConfig = argument.getValue().getTrainedModelConfig(); - assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); - } - - private ElserInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); - return new ElserInternalService(context); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java deleted file mode 100644 index f56e941dcc8c0..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserModelsTests.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elser; - -import org.elasticsearch.test.ESTestCase; - -public class ElserModelsTests extends ESTestCase { - - public static String randomElserModel() { - return randomFrom(ElserModels.VALID_ELSER_MODEL_IDS); - } - - public void testIsValidModel() { - assertTrue(ElserModels.isValidModel(randomElserModel())); - } - - public void testIsValidEisModel() { - assertTrue(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL)); - } - - public void testIsInvalidModel() { - assertFalse(ElserModels.isValidModel("invalid")); - } - - public void testIsInvalidEisModel() { - assertFalse(ElserModels.isValidEisModel(ElserModels.ELSER_V2_MODEL_LINUX_X86)); - } -} From 9352a30b33745a0dbf0a129e7e32cdbb1f230cb7 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 10:23:37 -0400 Subject: [PATCH 10/22] Add deprecation warning to infer API for elser --- .../action/TransportInferenceAction.java | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 4186b281a35b5..2c65067d75c9e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -11,6 +11,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.logging.DeprecationCategory; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.inference.InferenceService; @@ -23,8 +25,10 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.Set; @@ -42,6 +46,7 @@ public class TransportInferenceAction extends HandledTransportAction Date: Mon, 23 Sep 2024 12:59:41 -0400 Subject: [PATCH 11/22] Fix accidentally introduced NPE and retain BWC support for null model ID (with deprecation message) --- .../action/TransportInferenceAction.java | 3 +- .../ElasticsearchInternalService.java | 30 ++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 2c65067d75c9e..9ed320cf43c35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -25,7 +25,6 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; @@ -46,7 +45,7 @@ public class TransportInferenceAction extends HandledTransportAction modelListener ) { var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap); + final String defaultModelId = selectDefaultModelVariantBasedOnClusterArchitecture( + platformArchitectures, + ELSER_V2_MODEL_LINUX_X86, + ELSER_V2_MODEL + ); + if (false == defaultModelId.equals(esServiceSettingsBuilder.getModelId())) { - if (false == esServiceSettingsBuilder.getModelId() - .equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures, ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL))) { - if (esServiceSettingsBuilder.getModelId().equals(ELSER_V2_MODEL)) { + if (esServiceSettingsBuilder.getModelId() == null) { + // TODO remove this case once we remove the option to not pass model ID + esServiceSettingsBuilder.setModelId(defaultModelId); + } else if (esServiceSettingsBuilder.getModelId().equals(ELSER_V2_MODEL)) { logger.warn( "The platform agnostic model [{}] was requested on Linux x86_64. " + "It is recommended to use the optimized model instead [{}]", From 18b1ed09bcc64b6765bbfcb9be17c07ad807cd59 Mon Sep 17 00:00:00 2001 From: Max Hniebergall <137079448+maxhniebergall@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:15:02 -0400 Subject: [PATCH 12/22] change "area" to "REST API" because "Machine Learning" isn't an option for deprecation --- docs/changelog/113216.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog/113216.yaml b/docs/changelog/113216.yaml index b559ba8bb8825..dec0b991fdacf 100644 --- a/docs/changelog/113216.yaml +++ b/docs/changelog/113216.yaml @@ -5,6 +5,6 @@ type: deprecation issues: [] deprecation: title: "[Inference API] Deprecate elser service" - area: Machine Learning + area: REST API details: The `elser` service of the inference API will be removed in an upcoming release. Please use the elasticsearch service instead. impact: In the current version there is no impact. In a future version, users of the `elser` service will no longer be able to use it, and will be required to use the `elasticsearch` service to access elser through the inference API. From 749edc61fcf460069107a12cb1933c4d5ab552a4 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 16:04:27 -0400 Subject: [PATCH 13/22] change elser literals to static variable --- .../xpack/inference/action/TransportInferenceAction.java | 6 +++--- .../elasticsearch/ElasticsearchInternalService.java | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 9ed320cf43c35..afd0219157a5c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -34,6 +34,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; public class TransportInferenceAction extends HandledTransportAction { private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; @@ -79,13 +80,12 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe ); return; } - String ELSER_SERVICE_NAME = "elser"; - if (service.get().name().equals(ELSER_SERVICE_NAME)) { + if (service.get().name().equals(OLD_ELSER_SERVICE_NAME)) { DEPRECATION_LOGGER.warn( DeprecationCategory.API, "inference_api_elser_service", "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead.", - ELSER_SERVICE_NAME, + OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 1362b674710d7..9c7ba4fe82823 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -64,6 +64,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalService { public static final String NAME = "elasticsearch"; + public static final String OLD_ELSER_SERVICE_NAME = "elser"; static final String MULTILINGUAL_E5_SMALL_MODEL_ID = ".multilingual-e5-small"; static final String MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 = ".multilingual-e5-small_linux-x86_64"; From 4811955b28f0803659e63ad192bfedd74e336d08 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 16:05:36 -0400 Subject: [PATCH 14/22] change Put and Elasticsearch Internal service to pass the service name if it is elser or elasticsearch this will allow the elasticsearch service to maintain BWC for null model IDs if the service was elser. --- .../TransportPutInferenceModelAction.java | 14 +++++++---- .../ElasticsearchInternalService.java | 24 ++++++++++++------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 43c4a724f9571..654e69453dd76 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -55,6 +55,7 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.selectDefaultModelVariantBasedOnClusterArchitecture; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; @@ -167,8 +168,7 @@ protected void masterOperation( // Find the cluster platform as the service may need that // information when creating the model MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> { - String ELSER_SERVICE_NAME = "elser"; - if (serviceName.equals(ELSER_SERVICE_NAME)) { + if (serviceName.equals(OLD_ELSER_SERVICE_NAME)) { String modelId = selectDefaultModelVariantBasedOnClusterArchitecture( architectures, ELSER_V2_MODEL_LINUX_X86, @@ -180,7 +180,7 @@ protected void masterOperation( "inference_api_elser_service", "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" + " [model_id] set to [{}] in the [service_settings]", - ELSER_SERVICE_NAME, + OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME, modelId ); @@ -249,8 +249,14 @@ private void parseAndStoreModel( } }); - service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener); + { // required for BWC of elser service in elasticsearch service + Set localServices = Set.of(ElasticsearchInternalService.NAME, OLD_ELSER_SERVICE_NAME); + if (localServices.contains(service.name())) { + config.put(ModelConfigurations.SERVICE, service.name()); + } + } + service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener); } private void putAndStartModel(InferenceService service, Model model, ActionListener finalListener) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 9c7ba4fe82823..eda62edcdf077 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -13,6 +13,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.core.Nullable; @@ -96,20 +97,25 @@ public void parseRequestConfig( try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS); + String serviceName = (String) config.remove(ModelConfigurations.SERVICE); // required for elser service in elasticsearch service throwIfNotEmptyMap(config, name()); String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID); if (modelId == null) { - // TODO complete deprecation of null model ID - // throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); - DEPRECATION_LOGGER.critical( - DeprecationCategory.API, - "inference_api_null_model_id_in_elasticsearch_service", - "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" - + " deprecated and will be removed in a future release. Please specify a model_id field." - ); - elserCase(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); + if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) { + // TODO complete deprecation of null model ID + // throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); + DEPRECATION_LOGGER.critical( + DeprecationCategory.API, + "inference_api_null_model_id_in_elasticsearch_service", + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + + " deprecated and will be removed in a future release. Please specify a model_id field." + ); + elserCase(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); + } else { + throw new ValidationException().addValidationError("Error parsing request config, model id is missing"); + } } else if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { e5Case(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); } else if (VALID_ELSER_MODEL_IDS.contains(modelId)) { From 8b88ece81e756193b69148fedda2417b1e4e1cc8 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 16:05:54 -0400 Subject: [PATCH 15/22] fix up tests to match new elasticsearch service semantics regarding elser. --- .../ElasticsearchInternalServiceTests.java | 176 ++++++++++++++++-- 1 file changed, 157 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 257616033f080..4aa002c5b3ae5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -65,6 +65,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -93,9 +95,11 @@ public void shutdownThreadPool() { } public void testParseRequestConfig() { + // Null model variant var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) @@ -108,15 +112,16 @@ public void testParseRequestConfig() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, Set.of("not-linux"), modelListener); } public void testParseRequestConfig_Misconfigured() { - // Null model variant + // Non-existent model variant { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) @@ -129,20 +134,21 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, Set.of(), modelListener); } // Invalid config map { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) ) ); - settings.put("not_a_valid_config_setting", randomAlphaOfLength(10)); + config.put("not_a_valid_config_setting", randomAlphaOfLength(10)); ActionListener modelListener = ActionListener.wrap( model -> fail("Model parsing should have failed"), @@ -150,15 +156,16 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, Set.of(), modelListener); } } public void testParseRequestConfig_E5() { { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of( @@ -182,17 +189,18 @@ public void testParseRequestConfig_E5() { service.parseRequestConfig( randomInferenceEntityId, TaskType.TEXT_EMBEDDING, - settings, + config, Set.of(), - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } // Invalid service settings { var service = createService(mock(Client.class)); - var settings = new HashMap(); - settings.put( + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, ElasticsearchInternalService.NAME); + config.put( ModelConfigurations.SERVICE_SETTINGS, new HashMap<>( Map.of( @@ -213,7 +221,96 @@ public void testParseRequestConfig_E5() { e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) ); - service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, config, Set.of(), modelListener); + } + } + + public void testParseRequestConfig_elser() { + // General happy case + { + var service = createService(mock(Client.class)); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL + ) + ) + ); + + var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + config, + Set.of(), + getElserModelVerificationActionListener(elserServiceSettings, null) + ); + } + + // null model ID returns elser model for the provided platform (not linux) + { + var service = createService(mock(Client.class)); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of(ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, 1, ElasticsearchInternalServiceSettings.NUM_THREADS, 4) + ) + ); + + var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + + service.parseRequestConfig( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + config, + Set.of("not-linux"), + getElserModelVerificationActionListener( + elserServiceSettings, + new String[] { + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is deprecated" + + " and will be removed in a future release. Please specify a model_id field." } + ) + ); + } + + // Invalid service settings + { + var service = createService(mock(Client.class)); + var config = new HashMap(); + config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); + config.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + ElasticsearchInternalServiceSettings.MODEL_ID, + ElserModels.ELSER_V2_MODEL, + "not_a_valid_service_setting", + randomAlphaOfLength(10) + ) + ) + ); + + ActionListener modelListener = ActionListener.wrap( + model -> fail("Model parsing should have failed"), + e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) + ); + + service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, Set.of(), modelListener); } } @@ -344,7 +441,7 @@ public void testParseRequestConfig_SparseEmbedding() { service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelListener); } - private ActionListener getModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { + private ActionListener getE5ModelVerificationActionListener(MultilingualE5SmallInternalServiceSettings e5ServiceSettings) { return ActionListener.wrap(model -> { assertEquals( new MultilingualE5SmallModel( @@ -358,6 +455,47 @@ private ActionListener getModelVerificationActionListener(MultilingualE5S }, e -> { fail("Model parsing failed " + e.getMessage()); }); } + private ActionListener getElserModelVerificationActionListener( + ElserInternalServiceSettings elserServiceSettings, + String[] warnings + ) { + if (warnings != null) { + + return ActionListener.wrap(model -> { + assertCriticalWarnings( + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field" + + " is deprecated and will be removed in a future release. Please specify a model_id field." + ); + assertEquals( + new ElserInternalModel( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + NAME, + elserServiceSettings, + ElserMlNodeTaskSettings.DEFAULT + ), + model + ); + }, e -> { fail("Model parsing failed " + e.getMessage()); }); + + } else { + + return ActionListener.wrap(model -> { + assertEquals( + new ElserInternalModel( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + NAME, + elserServiceSettings, + ElserMlNodeTaskSettings.DEFAULT + ), + model + ); + }, e -> { fail("Model parsing failed " + e.getMessage()); }); + + } + } + public void testParsePersistedConfig() { // Null model variant From ed4c6f6a3c34a0520d594b743de7c84358b94b61 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 16:48:05 -0400 Subject: [PATCH 16/22] Move passing of service name --- .../action/TransportPutInferenceModelAction.java | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index 654e69453dd76..0496cec7b33d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -165,6 +165,10 @@ protected void masterOperation( } if (service.get().isInClusterService()) { + + // required for BWC of elser service in elasticsearch service TODO remove when elser service deprecated + requestAsMap.put(ModelConfigurations.SERVICE, serviceName); + // Find the cluster platform as the service may need that // information when creating the model MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> { @@ -249,13 +253,6 @@ private void parseAndStoreModel( } }); - { // required for BWC of elser service in elasticsearch service - Set localServices = Set.of(ElasticsearchInternalService.NAME, OLD_ELSER_SERVICE_NAME); - if (localServices.contains(service.name())) { - config.put(ModelConfigurations.SERVICE, service.name()); - } - } - service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener); } From b5d6a9f9b987cb0748967bfb6c2f89dab15b9340 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 16:48:25 -0400 Subject: [PATCH 17/22] add persistence for elser models in elasticsearch --- .../elasticsearch/ElasticsearchInternalService.java | 11 +++++++++++ .../ElasticsearchInternalServiceSettings.java | 1 + .../inference/services/elasticsearch/ElserModels.java | 2 +- .../elastic/ElasticInferenceServiceTests.java | 1 - 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index eda62edcdf077..f9fb34f8be20b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -73,6 +73,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); + public static final Set ELSER_VALID_IDS = Set.of(ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL); + + private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); @@ -336,6 +339,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M NAME, new MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)) ); + } else if (ElserModels.isValidModel(modelId)) { + return new ElserInternalModel( + inferenceEntityId, + taskType, + NAME, + new ElserInternalServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)), + ElserMlNodeTaskSettings.DEFAULT + ); } else { return createCustomElandModel( inferenceEntityId, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index f8b5837ef387e..38ba49fedd418 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -27,6 +27,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; public class ElasticsearchInternalServiceSettings implements ServiceSettings { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java index 44ee0f394354e..37f528ea3a750 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserModels.java @@ -23,7 +23,7 @@ public class ElserModels { ); public static boolean isValidModel(String model) { - return VALID_ELSER_MODEL_IDS.contains(model); + return model != null && VALID_ELSER_MODEL_IDS.contains(model); } public static boolean isValidEisModel(String model) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 758cd61b1b8ce..805931e728b9e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -37,7 +37,6 @@ import org.elasticsearch.xpack.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; -import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; From 49c98638af42cd3b9f2a99b80db82ddbebb943fa Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 14:09:46 -0400 Subject: [PATCH 18/22] copy elser service files into elasticsearch service --- .../ElasticsearchInternalService.java | 1 - .../ElserInternalServiceTests.java | 548 ++++++++++++++++++ 2 files changed, 548 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index f9fb34f8be20b..23a14da892ada 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -73,7 +73,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); - public static final Set ELSER_VALID_IDS = Set.of(ELSER_V2_MODEL_LINUX_X86, ELSER_V2_MODEL); private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java index e69de29bb2d1d..72849d3539343 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java @@ -0,0 +1,548 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elasticsearch; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceExtension; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; +import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; +import org.elasticsearch.xpack.inference.services.elser.ElserModels; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ElserInternalServiceTests extends ESTestCase { + + private static ThreadPool threadPool; + + @Before + public void setUpThreadPool() { + threadPool = new TestThreadPool("test"); + } + + @After + public void shutdownThreadPool() { + TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); + } + + public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { + return switch (taskType) { + case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + inferenceEntityId, + taskType, + ElserInternalService.NAME, + ElserInternalServiceSettingsTests.createRandom(), + ElserMlNodeTaskSettingsTests.createRandom() + ); + default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); + }; + } + + public void testParseConfigStrict() { + var service = createService(mock(Client.class)); + + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + "model_id", + ".elser_model_1" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); + + var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElserInternalService.NAME, + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), + ElserMlNodeTaskSettings.DEFAULT + ); + + var modelVerificationListener = getModelVerificationListener(expectedModel); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); + + } + + public void testParseConfigLooseWithOldModelId() { + var service = createService(mock(Client.class)); + + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + "model_version", + ".elser_model_1" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); + + var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElserInternalService.NAME, + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), + ElserMlNodeTaskSettings.DEFAULT + ); + + var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); + + assertEquals(expectedModel, realModel); + + } + + private static ActionListener getModelVerificationListener( + org.elasticsearch.xpack.inference.services.elser.ElserInternalModel expectedModel + ) { + return ActionListener.wrap( + (model) -> { assertEquals(expectedModel, model); }, + (e) -> fail("Model verification should not fail " + e.getMessage()) + ); + } + + public void testParseConfigStrictWithNoTaskSettings() { + var service = createService(mock(Client.class)); + + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4 + ) + ) + ); + + var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + ElserInternalService.NAME, + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), + ElserMlNodeTaskSettings.DEFAULT + ); + + var modelVerificationListener = getModelVerificationListener(expectedModel); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); + + } + + public void testParseConfigStrictWithUnknownSettings() { + + var service = createService(mock(Client.class)); + + for (boolean throwOnUnknown : new boolean[] { true, false }) { + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, + ".elser_model_2" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); + settings.put("foo", "bar"); + + ActionListener errorVerificationListener = ActionListener.wrap((model) -> { + if (throwOnUnknown) { + fail("Model verification should fail when throwOnUnknown is true"); + } + }, (e) -> { + if (throwOnUnknown) { + assertThat( + e.getMessage(), + containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") + ); + } else { + fail("Model verification should not fail when throwOnUnknown is false"); + } + }); + + if (throwOnUnknown == false) { + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); + } else { + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + } + } + + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, + ".elser_model_2" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); + + ActionListener errorVerificationListener = ActionListener.wrap((model) -> { + if (throwOnUnknown) { + fail("Model verification should fail when throwOnUnknown is true"); + } + }, (e) -> { + if (throwOnUnknown) { + assertThat( + e.getMessage(), + containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") + ); + } else { + fail("Model verification should not fail when throwOnUnknown is false"); + } + }); + if (throwOnUnknown == false) { + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); + } else { + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + } + } + + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, + ".elser_model_2", + "foo", + "bar" + ) + ) + ); + settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); + + ActionListener errorVerificationListener = ActionListener.wrap((model) -> { + if (throwOnUnknown) { + fail("Model verification should fail when throwOnUnknown is true"); + } + }, (e) -> { + if (throwOnUnknown) { + assertThat( + e.getMessage(), + containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") + ); + } else { + fail("Model verification should not fail when throwOnUnknown is false"); + } + }); + if (throwOnUnknown == false) { + var parsed = service.parsePersistedConfigWithSecrets( + "foo", + TaskType.SPARSE_EMBEDDING, + settings, + Collections.emptyMap() + ); + } else { + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); + } + } + } + } + + public void testParseRequestConfig_DefaultModel() { + var service = createService(mock(Client.class)); + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4 + ) + ) + ); + + ActionListener modelActionListener = ActionListener.wrap((model) -> { + assertEquals( + ".elser_model_2", + ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() + ); + }, (e) -> { fail("Model verification should not fail"); }); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelActionListener); + } + { + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, + 1, + org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, + 4 + ) + ) + ); + + ActionListener modelActionListener = ActionListener.wrap((model) -> { + assertEquals( + ".elser_model_2_linux-x86_64", + ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() + ); + }, (e) -> { fail("Model verification should not fail"); }); + + service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of("linux-x86_64"), modelActionListener); + } + } + + @SuppressWarnings("unchecked") + public void testChunkInfer() { + var mlTrainedModelResults = new ArrayList(); + mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); + mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); + mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); + var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); + + ThreadPool threadpool = new TestThreadPool("test"); + Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); + + var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + "elser", + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), + new ElserMlNodeTaskSettings() + ); + var service = createService(client); + + var gotResults = new AtomicBoolean(); + var resultsListener = ActionListener.>wrap(chunkedResponse -> { + assertThat(chunkedResponse, hasSize(3)); + assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); + assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); + assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); + var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); + assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); + var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); + assertThat(result3.getException(), instanceOf(RuntimeException.class)); + assertThat(result3.getException().getMessage(), containsString("boom")); + gotResults.set(true); + }, ESTestCase::fail); + + service.chunkedInfer( + model, + null, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) + ); + + if (gotResults.get() == false) { + terminate(threadpool); + } + assertTrue("Listener not called", gotResults.get()); + } + + @SuppressWarnings("unchecked") + public void testChunkInferSetsTokenization() { + var expectedSpan = new AtomicInteger(); + var expectedWindowSize = new AtomicReference(); + + ThreadPool threadpool = new TestThreadPool("test"); + Client client = mock(Client.class); + try { + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; + assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); + var update = (TokenizationConfigUpdate) request.getUpdate(); + assertEquals(update.getSpanSettings().span(), expectedSpan.get()); + assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); + return null; + }).when(client) + .execute( + same(InferTrainedModelDeploymentAction.INSTANCE), + any(InferTrainedModelDeploymentAction.Request.class), + any(ActionListener.class) + ); + + var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + "elser", + new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), + new ElserMlNodeTaskSettings() + ); + var service = createService(client); + + expectedSpan.set(-1); + expectedWindowSize.set(null); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + null, + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + + expectedSpan.set(-1); + expectedWindowSize.set(256); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(256, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + } finally { + terminate(threadpool); + } + } + + @SuppressWarnings("unchecked") + public void testPutModel() { + var client = mock(Client.class); + ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); + + doAnswer(invocation -> { + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); + return null; + }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); + + when(client.threadPool()).thenReturn(threadPool); + + var service = createService(client); + + var model = new ElserInternalModel( + "my-elser", + TaskType.SPARSE_EMBEDDING, + "elser", + new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), + ElserMlNodeTaskSettings.DEFAULT + ); + + service.putModel(model, new ActionListener<>() { + @Override + public void onResponse(Boolean success) { + assertTrue(success); + } + + @Override + public void onFailure(Exception e) { + fail(e); + } + }); + + var putConfig = argument.getValue().getTrainedModelConfig(); + assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); + } + + private ElserInternalService createService(Client client) { + var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); + return new ElserInternalService(context); + } +} From d02e44ef82b6401d7fcde39d866b076b15fb3191 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Thu, 19 Sep 2024 14:44:55 -0400 Subject: [PATCH 19/22] Add deprecation log message for elser service --- .../inference/InferenceServiceRegistry.java | 1 + .../ElasticsearchInternalServiceSettings.java | 1 - .../ElserInternalServiceTests.java | 548 ------------------ 3 files changed, 1 insertion(+), 549 deletions(-) delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index f1ce94173a550..4f2ba2f8b0d22 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -54,6 +54,7 @@ public Optional getService(String serviceName) { return Optional.ofNullable(services.get(serviceName)); } } + } public List getNamedWriteables() { return namedWriteables; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 38ba49fedd418..f8b5837ef387e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -27,7 +27,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; public class ElasticsearchInternalServiceSettings implements ServiceSettings { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java deleted file mode 100644 index 72849d3539343..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceTests.java +++ /dev/null @@ -1,548 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elasticsearch; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; -import org.elasticsearch.inference.InferenceResults; -import org.elasticsearch.inference.InferenceServiceExtension; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; -import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceChunkedTextExpansionResultsTests; -import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; -import org.elasticsearch.xpack.inference.services.elser.ElserModels; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ElserInternalServiceTests extends ESTestCase { - - private static ThreadPool threadPool; - - @Before - public void setUpThreadPool() { - threadPool = new TestThreadPool("test"); - } - - @After - public void shutdownThreadPool() { - TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); - } - - public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) { - return switch (taskType) { - case SPARSE_EMBEDDING -> new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - inferenceEntityId, - taskType, - ElserInternalService.NAME, - ElserInternalServiceSettingsTests.createRandom(), - ElserMlNodeTaskSettingsTests.createRandom() - ); - default -> throw new IllegalArgumentException("task type " + taskType + " is not supported"); - }; - } - - public void testParseConfigStrict() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_id", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - - } - - public void testParseConfigLooseWithOldModelId() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - "model_version", - ".elser_model_1" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - - var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ".elser_model_1", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var realModel = service.parsePersistedConfig("foo", TaskType.SPARSE_EMBEDDING, settings); - - assertEquals(expectedModel, realModel); - - } - - private static ActionListener getModelVerificationListener( - org.elasticsearch.xpack.inference.services.elser.ElserInternalModel expectedModel - ) { - return ActionListener.wrap( - (model) -> { assertEquals(expectedModel, model); }, - (e) -> fail("Model verification should not fail " + e.getMessage()) - ); - } - - public void testParseConfigStrictWithNoTaskSettings() { - var service = createService(mock(Client.class)); - - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4 - ) - ) - ); - - var expectedModel = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - ElserInternalService.NAME, - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null), - ElserMlNodeTaskSettings.DEFAULT - ); - - var modelVerificationListener = getModelVerificationListener(expectedModel); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelVerificationListener); - - } - - public void testParseConfigStrictWithUnknownSettings() { - - var service = createService(mock(Client.class)); - - for (boolean throwOnUnknown : new boolean[] { true, false }) { - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of()); - settings.put("foo", "bar"); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.MODEL_ID, - ".elser_model_2", - "foo", - "bar" - ) - ) - ); - settings.put(ModelConfigurations.TASK_SETTINGS, Map.of("foo", "bar")); - - ActionListener errorVerificationListener = ActionListener.wrap((model) -> { - if (throwOnUnknown) { - fail("Model verification should fail when throwOnUnknown is true"); - } - }, (e) -> { - if (throwOnUnknown) { - assertThat( - e.getMessage(), - containsString("Model configuration contains settings [{foo=bar}] unknown to the [elser] service") - ); - } else { - fail("Model verification should not fail when throwOnUnknown is false"); - } - }); - if (throwOnUnknown == false) { - var parsed = service.parsePersistedConfigWithSecrets( - "foo", - TaskType.SPARSE_EMBEDDING, - settings, - Collections.emptyMap() - ); - } else { - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), errorVerificationListener); - } - } - } - } - - public void testParseRequestConfig_DefaultModel() { - var service = createService(mock(Client.class)); - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4 - ) - ) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals( - ".elser_model_2", - ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() - ); - }, (e) -> { fail("Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of(), modelActionListener); - } - { - var settings = new HashMap(); - settings.put( - ModelConfigurations.SERVICE_SETTINGS, - new HashMap<>( - Map.of( - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_ALLOCATIONS, - 1, - org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings.NUM_THREADS, - 4 - ) - ) - ); - - ActionListener modelActionListener = ActionListener.wrap((model) -> { - assertEquals( - ".elser_model_2_linux-x86_64", - ((org.elasticsearch.xpack.inference.services.elser.ElserInternalModel) model).getServiceSettings().modelId() - ); - }, (e) -> { fail("Model verification should not fail"); }); - - service.parseRequestConfig("foo", TaskType.SPARSE_EMBEDDING, settings, Set.of("linux-x86_64"), modelActionListener); - } - } - - @SuppressWarnings("unchecked") - public void testChunkInfer() { - var mlTrainedModelResults = new ArrayList(); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(InferenceChunkedTextExpansionResultsTests.createRandomResults()); - mlTrainedModelResults.add(new ErrorInferenceResults(new RuntimeException("boom"))); - var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(response); - return null; - }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class)); - - var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - var gotResults = new AtomicBoolean(); - var resultsListener = ActionListener.>wrap(chunkedResponse -> { - assertThat(chunkedResponse, hasSize(3)); - assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result1 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(0); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(), result1.getChunkedResults()); - assertThat(chunkedResponse.get(1), instanceOf(InferenceChunkedSparseEmbeddingResults.class)); - var result2 = (InferenceChunkedSparseEmbeddingResults) chunkedResponse.get(1); - assertEquals(((MlChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(), result2.getChunkedResults()); - var result3 = (ErrorChunkedInferenceResults) chunkedResponse.get(2); - assertThat(result3.getException(), instanceOf(RuntimeException.class)); - assertThat(result3.getException().getMessage(), containsString("boom")); - gotResults.set(true); - }, ESTestCase::fail); - - service.chunkedInfer( - model, - null, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.runAfter(resultsListener, () -> terminate(threadpool)) - ); - - if (gotResults.get() == false) { - terminate(threadpool); - } - assertTrue("Listener not called", gotResults.get()); - } - - @SuppressWarnings("unchecked") - public void testChunkInferSetsTokenization() { - var expectedSpan = new AtomicInteger(); - var expectedWindowSize = new AtomicReference(); - - ThreadPool threadpool = new TestThreadPool("test"); - Client client = mock(Client.class); - try { - when(client.threadPool()).thenReturn(threadpool); - doAnswer(invocationOnMock -> { - var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; - assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); - var update = (TokenizationConfigUpdate) request.getUpdate(); - assertEquals(update.getSpanSettings().span(), expectedSpan.get()); - assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); - return null; - }).when(client) - .execute( - same(InferTrainedModelDeploymentAction.INSTANCE), - any(InferTrainedModelDeploymentAction.Request.class), - any(ActionListener.class) - ); - - var model = new org.elasticsearch.xpack.inference.services.elser.ElserInternalModel( - "foo", - TaskType.SPARSE_EMBEDDING, - "elser", - new org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings(1, 1, "elser", null), - new ElserMlNodeTaskSettings() - ); - var service = createService(client); - - expectedSpan.set(-1); - expectedWindowSize.set(null); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - null, - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - - expectedSpan.set(-1); - expectedWindowSize.set(256); - service.chunkedInfer( - model, - List.of("foo", "bar"), - Map.of(), - InputType.SEARCH, - new ChunkingOptions(256, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) - ); - } finally { - terminate(threadpool); - } - } - - @SuppressWarnings("unchecked") - public void testPutModel() { - var client = mock(Client.class); - ArgumentCaptor argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class); - - doAnswer(invocation -> { - var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class))); - return null; - }).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any()); - - when(client.threadPool()).thenReturn(threadPool); - - var service = createService(client); - - var model = new ElserInternalModel( - "my-elser", - TaskType.SPARSE_EMBEDDING, - "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), - ElserMlNodeTaskSettings.DEFAULT - ); - - service.putModel(model, new ActionListener<>() { - @Override - public void onResponse(Boolean success) { - assertTrue(success); - } - - @Override - public void onFailure(Exception e) { - fail(e); - } - }); - - var putConfig = argument.getValue().getTrainedModelConfig(); - assertEquals("text_field", putConfig.getInput().getFieldNames().get(0)); - } - - private ElserInternalService createService(Client client) { - var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); - return new ElserInternalService(context); - } -} From ed234c774c78ff1b15162ce2468bc84fcf23fde4 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Mon, 23 Sep 2024 10:23:37 -0400 Subject: [PATCH 20/22] Add deprecation warning to infer API for elser From 0a0e3c5cf5025bb852c418616da67f8e0995f081 Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Fri, 27 Sep 2024 13:03:50 -0400 Subject: [PATCH 21/22] fix merge conflicts --- .../org/elasticsearch/inference/InferenceServiceRegistry.java | 1 - .../services/elasticsearch/ElasticsearchInternalService.java | 2 -- .../xpack/inference/ModelConfigurationsTests.java | 4 ++-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java index 4f2ba2f8b0d22..f1ce94173a550 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceRegistry.java @@ -54,7 +54,6 @@ public Optional getService(String serviceName) { return Optional.ofNullable(services.get(serviceName)); } } - } public List getNamedWriteables() { return namedWriteables; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 23a14da892ada..810d3084bec1a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -74,8 +74,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 ); - private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); - private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java index df3712ea344a2..03613901c7816 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/ModelConfigurationsTests.java @@ -14,10 +14,10 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; public class ModelConfigurationsTests extends AbstractWireSerializingTestCase { From 703bc91d2d925b3925da008d640d8ec79b49188d Mon Sep 17 00:00:00 2001 From: Max Hniebergall Date: Wed, 2 Oct 2024 17:22:02 -0400 Subject: [PATCH 22/22] fix merge --- .../integration/ModelRegistryIT.java | 6 +- .../action/TransportInferenceAction.java | 3 - .../TransportPutInferenceModelAction.java | 30 +---- .../BaseElasticsearchInternalService.java | 1 - .../ElasticsearchInternalService.java | 37 +++++- .../ElasticsearchInternalServiceTests.java | 110 +++++++++--------- 6 files changed, 94 insertions(+), 93 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index c81c209241cc6..ea8b32f36f54c 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -28,8 +28,8 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettingsTests; import org.junit.Before; @@ -117,10 +117,10 @@ public void testGetModel() throws Exception { assertEquals(model.getConfigurations().getService(), modelHolder.get().service()); - var elserService = new ElserInternalService( + var elserService = new ElasticsearchInternalService( new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class)) ); - ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets( + ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets( modelHolder.get().inferenceEntityId(), modelHolder.get().taskType(), modelHolder.get().settings(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 4a466103ba710..d2a73b7df77c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.ChunkedToXContent; @@ -27,14 +26,12 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; public class TransportInferenceAction extends HandledTransportAction { private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index fef4c4fef6974..49d65b6e0dc59 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -19,9 +19,6 @@ import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.logging.DeprecationCategory; -import org.elasticsearch.common.logging.DeprecationLogger; -import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.XContentHelper; @@ -47,20 +44,17 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import java.io.IOException; +import java.util.List; import java.util.Map; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.services.elasticsearch.BaseElasticsearchInternalService.selectDefaultModelVariantBasedOnClusterArchitecture; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; -import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; -import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; public class TransportPutInferenceModelAction extends TransportMasterNodeAction< PutInferenceModelAction.Request, PutInferenceModelAction.Response> { private static final Logger logger = LogManager.getLogger(TransportPutInferenceModelAction.class); - private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(PutInferenceModelAction.class); private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; @@ -119,6 +113,10 @@ protected void masterOperation( return; } + if (List.of(OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME).contains(serviceName)) { + // required for BWC of elser service in elasticsearch service TODO remove when elser service deprecated + requestAsMap.put(ModelConfigurations.SERVICE, serviceName); + } var service = serviceRegistry.getService(serviceName); if (service.isEmpty()) { listener.onFailure(new ElasticsearchStatusException("Unknown service [{}]", RestStatus.BAD_REQUEST, serviceName)); @@ -160,24 +158,6 @@ protected void masterOperation( return; } - if (serviceName.equals(OLD_ELSER_SERVICE_NAME)) { - String modelId = selectDefaultModelVariantBasedOnClusterArchitecture( - architectures, - ELSER_V2_MODEL_LINUX_X86, - ELSER_V2_MODEL - ); - - DEPRECATION_LOGGER.warn( - DeprecationCategory.API, - "inference_api_elser_service", - "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" - + " [model_id] set to [{}] in the [service_settings]", - OLD_ELSER_SERVICE_NAME, - ElasticsearchInternalService.NAME, - modelId - ); - } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 1dd7a36315c19..23e806e01300a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -32,7 +32,6 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil; import org.elasticsearch.xpack.inference.InferencePlugin; -import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel; import java.io.IOException; import java.util.EnumSet; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 22f586597e3f1..e274c641e30be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -13,7 +13,6 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.core.Nullable; @@ -61,7 +60,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86; -import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.VALID_ELSER_MODEL_IDS; public class ElasticsearchInternalService extends BaseElasticsearchInternalService { @@ -120,13 +118,26 @@ public void parseRequestConfig( "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + " deprecated and will be removed in a future release. Please specify a model_id field." ); - elserCase(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener); } - if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); + } else { + throw new IllegalArgumentException("Error parsing service settings, model_id must be provided"); + } + } else if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { platformArch.accept( modelListener.delegateFailureAndWrap( (delegate, arch) -> e5Case(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) ) ); + } else if (ElserModels.isValidModel(modelId)) { + platformArch.accept( + modelListener.delegateFailureAndWrap( + (delegate, arch) -> elserCase(inferenceEntityId, taskType, config, arch, serviceSettingsMap, modelListener) + ) + ); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, modelListener); } @@ -259,6 +270,14 @@ static boolean modelVariantValidForArchitecture(Set platformArchitecture // platform agnostic model is always compatible return true; } + return modelId.equals( + selectDefaultModelVariantBasedOnClusterArchitecture( + platformArchitectures, + MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86, + MULTILINGUAL_E5_SMALL_MODEL_ID + ) + ); + } private void elserCase( String inferenceEntityId, @@ -295,6 +314,16 @@ private void elserCase( } } + DEPRECATION_LOGGER.warn( + DeprecationCategory.API, + "inference_api_elser_service", + "The [{}] service is deprecated and will be removed in a future release. Use the [{}] service instead, with" + + " [model_id] set to [{}] in the [service_settings]", + OLD_ELSER_SERVICE_NAME, + ElasticsearchInternalService.NAME, + defaultModelId + ); + if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) { throw new IllegalArgumentException( "Error parsing request config, model id does not match any models available on this platform. Was [" diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 734ac8089d3c5..cd6da4c0ad8d8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -9,10 +9,12 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.apache.logging.log4j.Level; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -67,10 +69,10 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; -import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; +import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -116,7 +118,7 @@ public void testParseRequestConfig() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } public void testParseRequestConfig_Misconfigured() { @@ -138,7 +140,7 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } // Invalid config map @@ -160,7 +162,7 @@ public void testParseRequestConfig_Misconfigured() { ); var taskType = randomFrom(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING); - service.parseRequestConfig(randomInferenceEntityId, taskType, settings, modelListener); + service.parseRequestConfig(randomInferenceEntityId, taskType, config, modelListener); } } @@ -188,7 +190,7 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } @@ -220,7 +222,7 @@ public void testParseRequestConfig_E5() { randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, - getModelVerificationActionListener(e5ServiceSettings) + getE5ModelVerificationActionListener(e5ServiceSettings) ); } @@ -249,14 +251,16 @@ public void testParseRequestConfig_E5() { e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) ); - service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, config, modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings, modelListener); } } public void testParseRequestConfig_elser() { // General happy case { - var service = createService(mock(Client.class)); + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); var config = new HashMap(); config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); config.put( @@ -279,14 +283,20 @@ public void testParseRequestConfig_elser() { randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, - Set.of(), - getElserModelVerificationActionListener(elserServiceSettings, null) + getElserModelVerificationActionListener( + elserServiceSettings, + null, + "The [elser] service is deprecated and will be removed in a future release. Use the [elasticsearch] service " + + "instead, with [model_id] set to [.elser_model_2] in the [service_settings]" + ) ); } // null model ID returns elser model for the provided platform (not linux) { - var service = createService(mock(Client.class)); + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); var config = new HashMap(); config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); config.put( @@ -298,23 +308,26 @@ public void testParseRequestConfig_elser() { var elserServiceSettings = new ElserInternalServiceSettings(1, 4, ElserModels.ELSER_V2_MODEL, null); + String criticalWarning = + "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is" + + " deprecated and will be removed in a future release. Please specify a model_id field."; + String warnWarning = + "The [elser] service is deprecated and will be removed in a future release. Use the [elasticsearch] service " + + "instead, with [model_id] set to [.elser_model_2] in the [service_settings]"; service.parseRequestConfig( randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, - Set.of("not-linux"), - getElserModelVerificationActionListener( - elserServiceSettings, - new String[] { - "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is deprecated" - + " and will be removed in a future release. Please specify a model_id field." } - ) + getElserModelVerificationActionListener(elserServiceSettings, criticalWarning, warnWarning) ); + assertWarnings(true, new DeprecationWarning(DeprecationLogger.CRITICAL, criticalWarning)); } // Invalid service settings { - var service = createService(mock(Client.class)); + Client mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(threadPool); + var service = createService(mockClient); var config = new HashMap(); config.put(ModelConfigurations.SERVICE, OLD_ELSER_SERVICE_NAME); config.put( @@ -338,7 +351,7 @@ public void testParseRequestConfig_elser() { e -> assertThat(e, instanceOf(ElasticsearchStatusException.class)) ); - service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, Set.of(), modelListener); + service.parseRequestConfig(randomInferenceEntityId, TaskType.SPARSE_EMBEDDING, config, modelListener); } } @@ -485,43 +498,26 @@ private ActionListener getE5ModelVerificationActionListener(MultilingualE private ActionListener getElserModelVerificationActionListener( ElserInternalServiceSettings elserServiceSettings, - String[] warnings + String criticalWarning, + String warnWarning ) { - if (warnings != null) { - - return ActionListener.wrap(model -> { - assertCriticalWarnings( - "Putting elasticsearch service inference endpoints (including elser service) without a model_id field" - + " is deprecated and will be removed in a future release. Please specify a model_id field." - ); - assertEquals( - new ElserInternalModel( - randomInferenceEntityId, - TaskType.SPARSE_EMBEDDING, - NAME, - elserServiceSettings, - ElserMlNodeTaskSettings.DEFAULT - ), - model - ); - }, e -> { fail("Model parsing failed " + e.getMessage()); }); - - } else { - - return ActionListener.wrap(model -> { - assertEquals( - new ElserInternalModel( - randomInferenceEntityId, - TaskType.SPARSE_EMBEDDING, - NAME, - elserServiceSettings, - ElserMlNodeTaskSettings.DEFAULT - ), - model - ); - }, e -> { fail("Model parsing failed " + e.getMessage()); }); - - } + return ActionListener.wrap(model -> { + assertWarnings( + true, + new DeprecationWarning(DeprecationLogger.CRITICAL, criticalWarning), + new DeprecationWarning(Level.WARN, warnWarning) + ); + assertEquals( + new ElserInternalModel( + randomInferenceEntityId, + TaskType.SPARSE_EMBEDDING, + NAME, + elserServiceSettings, + ElserMlNodeTaskSettings.DEFAULT + ), + model + ); + }, e -> { fail("Model parsing failed " + e.getMessage()); }); } public void testParsePersistedConfig() {