diff --git a/docs/changelog/137434.yaml b/docs/changelog/137434.yaml new file mode 100644 index 0000000000000..cb8bb32389de5 --- /dev/null +++ b/docs/changelog/137434.yaml @@ -0,0 +1,5 @@ +pr: 137434 +summary: Require basic licence for the Elastic Inference Service +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java index ca55031a1f5c9..5ace318bedcd6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java @@ -32,6 +32,8 @@ public final class XPackField { public static final String UPGRADE = "upgrade"; // inside of YAML settings we still use xpack do not having handle issues with dashes public static final String SETTINGS_NAME = "xpack"; + /** Name constant for the EIS feature. */ + public static final String ELASTIC_INFERENCE_SERVICE = "Elastic Inference Service"; /** Name constant for the eql feature. */ public static final String EQL = "eql"; /** Name constant for the esql feature. */ diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 09834e6a91210..63d326016a5ec 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -34,7 +34,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase { private static ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) - .setting("xpack.license.self_generated.type", "trial") + .setting("xpack.license.self_generated.type", "basic") .setting("xpack.security.enabled", "true") // Adding both settings unless one feature flag is disabled in a particular environment .setting("xpack.inference.elastic.url", mockEISServer::getUrl) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java index 4400ad8bbb538..e4d2fb5bc29bc 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java @@ -45,28 +45,6 @@ public void testPutModel_RestrictedWithBasicLicense() throws Exception { sendRestrictedRequest("PUT", endpoint, modelConfig); } - public void testUpdateModel_RestrictedWithBasicLicense() throws Exception { - var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id"); - var requestBody = """ - { - "task_settings": { - "num_threads": 2 - } - } - """; - sendRestrictedRequest("PUT", endpoint, requestBody); - } - - public void testPerformInference_RestrictedWithBasicLicense() throws Exception { - var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id"); - var requestBody = """ - { - "input": ["washing", "machine"] - } - """; - sendRestrictedRequest("POST", endpoint, requestBody); - } - public void testGetServices_NonRestrictedWithBasicLicense() throws Exception { var endpoint = "_inference/_services"; sendNonRestrictedRequest("GET", endpoint, null, 200, false); diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java index 30b8a636b9ac6..fdaac19479d74 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java @@ -46,6 +46,8 @@ @ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase { public static final String INDEX_NAME = "test-index"; + private static final String SPARSE_INFERENCE_ID = "sparse-endpoint"; + private static final String DENSE_INFERENCE_ID = "dense-endpoint"; private final boolean useLegacyFormat; @@ -61,9 +63,9 @@ public static Iterable parameters() { @Before public void setup() throws Exception { ModelRegistry modelRegistry = internalCluster().getCurrentMasterNodeInstance(ModelRegistry.class); - Utils.storeSparseModel("sparse-endpoint", modelRegistry); + Utils.storeSparseModel(SPARSE_INFERENCE_ID, modelRegistry); Utils.storeDenseModel( - "dense-endpoint", + DENSE_INFERENCE_ID, modelRegistry, randomIntBetween(1, 100), // dot product means that we need normalized vectors; it's not worth doing that in this test @@ -92,27 +94,20 @@ public Settings indexSettings() { } public void testLicenseInvalidForInference() { - prepareCreate(INDEX_NAME).setMapping( - String.format( - Locale.ROOT, - """ - { - "properties": { - "sparse_field": { - "type": "semantic_text", - "inference_id": "%s" - }, - "dense_field": { - "type": "semantic_text", - "inference_id": "%s" - } - } + prepareCreate(INDEX_NAME).setMapping(String.format(Locale.ROOT, """ + { + "properties": { + "sparse_field": { + "type": "semantic_text", + "inference_id": "%s" + }, + "dense_field": { + "type": "semantic_text", + "inference_id": "%s" } - """, - TestSparseInferenceServiceExtension.TestInferenceService.NAME, - TestDenseInferenceServiceExtension.TestInferenceService.NAME - ) - ).get(); + } + } + """, SPARSE_INFERENCE_ID, DENSE_INFERENCE_ID)).get(); BulkRequestBuilder bulkRequest = client().prepareBulk(); int totalBulkReqs = randomIntBetween(2, 100); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceLicenceCheck.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceLicenceCheck.java new file mode 100644 index 0000000000000..c9e97617546c7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceLicenceCheck.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; + +import static org.elasticsearch.xpack.inference.InferencePlugin.EIS_INFERENCE_FEATURE; +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; + +public class InferenceLicenceCheck { + + private InferenceLicenceCheck() {} + + public static boolean isServiceLicenced(String serviceName, XPackLicenseState licenseState) { + if (ElasticInferenceService.NAME.equals(serviceName)) { + return EIS_INFERENCE_FEATURE.check(licenseState); + } else { + return INFERENCE_API_FEATURE.check(licenseState); + } + } + + public static ElasticsearchSecurityException complianceException(String serviceName) { + if (ElasticInferenceService.NAME.equals(serviceName)) { + return LicenseUtils.newComplianceException(XPackField.ELASTIC_INFERENCE_SERVICE); + } else { + return LicenseUtils.newComplianceException(XPackField.INFERENCE); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 60592c5dd1dbd..2b071b2c5f4fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -203,6 +203,12 @@ public class InferencePlugin extends Plugin License.OperationMode.ENTERPRISE ); + public static final LicensedFeature.Momentary EIS_INFERENCE_FEATURE = LicensedFeature.momentary( + "inference", + "Elastic Inference Service", + License.OperationMode.BASIC + ); + public static final String X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER = "X-elastic-product-use-case"; public static final String NAME = "inference"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 7e3be2e1a08c2..5540f2798a0cd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -25,15 +25,14 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.telemetry.InferenceStats; -import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.InferenceLicenceCheck; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.InferenceEndpointRegistry; @@ -51,7 +50,6 @@ import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.inference.telemetry.InferenceStats.responseAttributes; import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes; -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; /** * Base class for transport actions that handle inference requests. @@ -112,16 +110,17 @@ protected abstract void doInference( @Override protected void doExecute(Task task, Request request, ActionListener listener) { - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE)); - return; - } var timer = InferenceTimer.start(); var getModelListener = ActionListener.wrap((Model model) -> { var serviceName = model.getConfigurations().getService(); + if (InferenceLicenceCheck.isServiceLicenced(serviceName, licenseState) == false) { + listener.onFailure(InferenceLicenceCheck.complianceException(serviceName)); + return; + } + try { validateRequest(request, model); } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index f7a563d9bfed9..b472beebb66c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.internal.Client; -import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -34,7 +33,6 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -42,11 +40,11 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.inference.InferenceLicenceCheck; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -60,8 +58,6 @@ import java.util.Set; import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.inference.common.SemanticTextInfoExtractor.getModelSettingsForIndicesReferencingInferenceEndpoints; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.canMergeModelSettings; @@ -76,7 +72,6 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction< private final XPackLicenseState licenseState; private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; - private final OriginSettingClient client; private volatile boolean skipValidationAndStart; private final ProjectResolver projectResolver; @@ -110,7 +105,6 @@ public TransportPutInferenceModelAction( clusterService.getClusterSettings() .addSettingsUpdateConsumer(InferencePlugin.SKIP_VALIDATE_AND_START, this::setSkipValidationAndStart); this.projectResolver = projectResolver; - this.client = new OriginSettingClient(client, INFERENCE_ORIGIN); } @Override @@ -120,11 +114,6 @@ protected void masterOperation( ClusterState state, ActionListener listener ) throws Exception { - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE)); - return; - } - if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) { listener.onFailure( new ElasticsearchStatusException( @@ -150,6 +139,11 @@ protected void masterOperation( return; } + if (InferenceLicenceCheck.isServiceLicenced(serviceName, licenseState) == false) { + listener.onFailure(InferenceLicenceCheck.complianceException(serviceName)); + return; + } + if (List.of(OLD_ELSER_SERVICE_NAME, ElasticsearchInternalService.NAME).contains(serviceName)) { // required for BWC of elser service in elasticsearch service TODO remove when elser service deprecated requestAsMap.put(ModelConfigurations.SERVICE, serviceName); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java index 884de4d6461ed..359cfc70546c1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelAction.java @@ -35,7 +35,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -43,13 +42,13 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.inference.InferenceLicenceCheck; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; @@ -61,7 +60,6 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.services.ServiceUtils.resolveTaskType; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; @@ -113,11 +111,6 @@ protected void masterOperation( ClusterState state, ActionListener masterListener ) { - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - masterListener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE)); - return; - } - var bodyTaskType = request.getContentAsSettings().taskType(); var resolvedTaskType = resolveTaskType(request.getTaskType(), bodyTaskType != null ? bodyTaskType.toString() : null); @@ -137,10 +130,16 @@ protected void masterOperation( unparsedModel.service() ) ); - } else { - service.set(optionalService.get()); - listener.onResponse(unparsedModel); + return; } + + if (InferenceLicenceCheck.isServiceLicenced(optionalService.get().name(), licenseState) == false) { + listener.onFailure(InferenceLicenceCheck.complianceException(optionalService.get().name())); + return; + } + + service.set(optionalService.get()); + listener.onResponse(unparsedModel); }) .andThen((listener, existingUnparsedModel) -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 1e7fd71c9da28..240d4863c7a18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -48,7 +48,6 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; -import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; @@ -57,10 +56,10 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferenceException; +import org.elasticsearch.xpack.inference.InferenceLicenceCheck; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; @@ -78,7 +77,6 @@ import java.util.stream.Collectors; import static org.elasticsearch.inference.telemetry.InferenceStats.serviceAndResponseAttributes; -import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy; @@ -383,6 +381,17 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } + + if (InferenceLicenceCheck.isServiceLicenced(inferenceProvider.service.name(), licenseState) == false) { + try (onFinish) { + var complianceException = InferenceLicenceCheck.complianceException(inferenceProvider.service.name()); + for (FieldInferenceRequest request : requests) { + addInferenceResponseFailure(request.bulkItemIndex, complianceException); + } + return; + } + } + final List inputs = requests.stream() .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) .collect(Collectors.toList()); @@ -571,11 +580,6 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< break; } - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); - break; - } - List requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); int offsetAdjustment = 0; for (String v : values) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InferenceLicenceCheckTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InferenceLicenceCheckTests.java new file mode 100644 index 0000000000000..b2311ecb7dfa4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InferenceLicenceCheckTests.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.license.MockLicenseState; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Mockito.when; + +public class InferenceLicenceCheckTests extends ESTestCase { + public void testIsServiceLicenced_WithElasticInferenceService_WhenLicensed() { + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.EIS_INFERENCE_FEATURE)).thenReturn(true); + + assertTrue(InferenceLicenceCheck.isServiceLicenced(ElasticInferenceService.NAME, licenseState)); + } + + public void testIsServiceLicenced_WithElasticInferenceService_WhenNotLicensed() { + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.EIS_INFERENCE_FEATURE)).thenReturn(false); + + assertFalse(InferenceLicenceCheck.isServiceLicenced(ElasticInferenceService.NAME, licenseState)); + } + + public void testIsServiceLicenced_WithOtherService_WhenLicensed() { + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true); + boolean result = InferenceLicenceCheck.isServiceLicenced("openai", licenseState); + assertTrue(result); + } + + public void testIsServiceLicenced_WithOtherService_WhenNotLicensed() { + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(false); + boolean result = InferenceLicenceCheck.isServiceLicenced("cohere", licenseState); + assertFalse(result); + } + + public void testIsServiceLicenced_WithMultipleServices() { + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.EIS_INFERENCE_FEATURE)).thenReturn(true); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(false); + + // Elastic Inference Service should be licensed + assertTrue(InferenceLicenceCheck.isServiceLicenced(ElasticInferenceService.NAME, licenseState)); + + // Other services should not be licensed + assertFalse(InferenceLicenceCheck.isServiceLicenced("openai", licenseState)); + assertFalse(InferenceLicenceCheck.isServiceLicenced("huggingface", licenseState)); + } + + public void testComplianceException_WithElasticInferenceService() { + ElasticsearchSecurityException exception = InferenceLicenceCheck.complianceException(ElasticInferenceService.NAME); + + assertNotNull(exception); + assertThat(exception.getMessage(), containsString("current license is non-compliant for [Elastic Inference Service]")); + } + + public void testComplianceException_WithOtherService() { + ElasticsearchSecurityException exception = InferenceLicenceCheck.complianceException("openai"); + + assertNotNull(exception); + assertThat(exception.getMessage(), containsString("current license is non-compliant for [inference]")); + } + + public void testComplianceException_WithMultipleServices() { + // Test that different services return different exceptions + ElasticsearchSecurityException eisException = InferenceLicenceCheck.complianceException(ElasticInferenceService.NAME); + ElasticsearchSecurityException openaiException = InferenceLicenceCheck.complianceException("openai"); + ElasticsearchSecurityException cohereException = InferenceLicenceCheck.complianceException("cohere"); + + assertThat(eisException.getMessage(), containsString("Elastic Inference Service")); + assertThat(openaiException.getMessage(), containsString("inference")); + assertThat(cohereException.getMessage(), containsString("inference")); + + // The two non-EIS exceptions should have the same message + assertEquals(openaiException.getMessage(), cohereException.getMessage()); + } + + public void testComplianceException_WithNullServiceName() { + ElasticsearchSecurityException exception = InferenceLicenceCheck.complianceException(null); + + assertNotNull(exception); + assertThat(exception.getMessage(), containsString("current license is non-compliant for [inference]")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java new file mode 100644 index 0000000000000..1d0b8806bc2b4 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUpdateInferenceModelActionTests.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.project.TestProjectResolvers; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.license.MockLicenseState; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.junit.Before; + +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TransportUpdateInferenceModelActionTests extends ESTestCase { + + private MockLicenseState licenseState; + private TransportUpdateInferenceModelAction action; + private ThreadPool threadPool; + private ModelRegistry mockModelRegistry; + private InferenceServiceRegistry mockInferenceServiceRegistry; + + @Before + public void createAction() throws Exception { + super.setUp(); + threadPool = mock(ThreadPool.class); + mockModelRegistry = mock(ModelRegistry.class); + mockInferenceServiceRegistry = mock(InferenceServiceRegistry.class); + licenseState = MockLicenseState.createMock(); + action = new TransportUpdateInferenceModelAction( + mock(TransportService.class), + mock(ClusterService.class), + threadPool, + mock(ActionFilters.class), + licenseState, + mockModelRegistry, + mockInferenceServiceRegistry, + mock(Client.class), + TestProjectResolvers.DEFAULT_PROJECT_ONLY + ); + + } + + public void testLicenseCheck_NotAllowed() { + mocks("enterprise_licensed_service", false); + + var listener = new PlainActionFuture(); + + String requestBody = "{\"service_settings\": {\"api_key\": \"\"}}"; + + action.masterOperation( + mock(Task.class), + new UpdateInferenceModelAction.Request( + "model-id", + new BytesArray(requestBody), + XContentType.JSON, + TaskType.TEXT_EMBEDDING, + TimeValue.timeValueSeconds(1) + ), + ClusterState.EMPTY_STATE, + listener + ); + + var exception = expectThrows(ElasticsearchSecurityException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(5))); + assertThat(exception.getMessage(), is("current license is non-compliant for [inference]")); + } + + private void mocks(String serviceName, boolean isAllowed) { + doAnswer(invocationOnMock -> { + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, serviceName, Map.of(), Map.of())); + return Void.TYPE; + }).when(mockModelRegistry).getModelWithSecrets(anyString(), any()); + + var mockService = mock(InferenceService.class); + when(mockService.name()).thenReturn(serviceName); + when(mockInferenceServiceRegistry.getService(anyString())).thenReturn(Optional.of(mockService)); + + when(licenseState.isAllowed(any())).thenReturn(isAllowed); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 715198e8c7804..070a6dc8a9538 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -41,6 +41,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; @@ -48,6 +49,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.inference.telemetry.InferenceStats; @@ -69,6 +71,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.junit.After; import org.junit.Before; import org.mockito.stubbing.Answer; @@ -151,14 +154,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); - ShardBulkInferenceActionFilter filter = createFilter( - threadPool, - Map.of(), - NOOP_INDEXING_PRESSURE, - useLegacyFormat, - true, - inferenceStats - ); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, inferenceStats); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -185,12 +181,70 @@ public void testFilterNoop() throws Exception { public void testLicenseInvalidForInference() throws InterruptedException { final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); StaticModel model = StaticModel.createRandomInstance(); + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(false); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(model.getInferenceEntityId(), model), + NOOP_INDEXING_PRESSURE, + useLegacyFormat, + licenseState, + inferenceStats + ); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertThat(bulkShardRequest.items().length, equalTo(1)); + + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertNotNull(failure); + assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class)); + assertThat( + failure.getMessage(), + containsString(org.elasticsearch.core.Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE)) + ); + } finally { + chainExecuted.countDown(); + } + + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "obj.field1", + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) + ); + BulkItemRequest[] items = new BulkItemRequest[1]; + items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test")); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testLicenseInvalidForEis() throws InterruptedException { + final InferenceStats inferenceStats = InferenceStatsTests.mockInferenceStats(); + StaticModel model = new StaticModel( + randomAlphanumericOfLength(5), + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + new TestModel.TestServiceSettings("foo", 128, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE), + new TestModel.TestTaskSettings(randomInt(3)), + new TestModel.TestSecretSettings(randomAlphaOfLength(4)) + ); + + var licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.EIS_INFERENCE_FEATURE)).thenReturn(false); ShardBulkInferenceActionFilter filter = createFilter( threadPool, - Map.of(), + Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - false, + licenseState, inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -236,7 +290,6 @@ public void testInferenceNotFound() throws Exception { Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true, inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -284,7 +337,6 @@ public void testItemFailures() throws Exception { Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true, inferenceStats ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); @@ -375,7 +427,6 @@ public void testExplicitNull() throws Exception { Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true, inferenceStats ); @@ -448,7 +499,6 @@ public void testHandleEmptyInput() throws Exception { Map.of(model.getInferenceEntityId(), model), NOOP_INDEXING_PRESSURE, useLegacyFormat, - true, inferenceStats ); @@ -526,7 +576,6 @@ public void testManyRandomDocs() throws Exception { inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, - true, inferenceStats ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -569,7 +618,6 @@ public void testIndexingPressure() throws Exception { Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel), indexingPressure, useLegacyFormat, - true, inferenceStats ); @@ -688,7 +736,6 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true, inferenceStats ); @@ -775,7 +822,6 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true, inferenceStats ); @@ -888,7 +934,6 @@ public void testIndexingPressurePartialFailure() throws Exception { Map.of(sparseModel.getInferenceEntityId(), sparseModel), indexingPressure, useLegacyFormat, - true, inferenceStats ); @@ -961,13 +1006,25 @@ public void testIndexingPressurePartialFailure() throws Exception { verify(coordinatingIndexingPressure).close(); } + private static ShardBulkInferenceActionFilter createFilter( + ThreadPool threadPool, + Map modelMap, + IndexingPressure indexingPressure, + boolean useLegacyFormat, + InferenceStats inferenceStats + ) { + MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true); + return createFilter(threadPool, modelMap, indexingPressure, useLegacyFormat, licenseState, inferenceStats); + } + @SuppressWarnings("unchecked") private static ShardBulkInferenceActionFilter createFilter( ThreadPool threadPool, Map modelMap, IndexingPressure indexingPressure, boolean useLegacyFormat, - boolean isLicenseValidForInference, + MockLicenseState licenseState, InferenceStats inferenceStats ) { ModelRegistry modelRegistry = mock(ModelRegistry.class); @@ -1037,9 +1094,6 @@ private static ShardBulkInferenceActionFilter createFilter( InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); - MockLicenseState licenseState = MockLicenseState.createMock(); - when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(isLicenseValidForInference); - return new ShardBulkInferenceActionFilter( createClusterService(useLegacyFormat), inferenceServiceRegistry,