From fb621c404203d94844e8e6d99cbc9ee054e2c6c5 Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Thu, 20 Feb 2025 14:06:40 -0500 Subject: [PATCH] Add enterprise license check to inference action for semantic text fields (#122293) * Add enterprise license check to inference action for semantic text fields * Update docs/changelog/122293.yaml * Set license to trial in ShardBulkInferenceActionFilterIT * Move license check to only block semantic_text fields that require inference call * Cleaning up tests * Add parameterization on useLegacyFormat back in ShardBulkInferenceActionFilterBasicLicenseIT --------- Co-authored-by: Elastic Machine --- docs/changelog/122293.yaml | 5 + ...lkInferenceActionFilterBasicLicenseIT.java | 168 ++++++++++++++++++ .../ShardBulkInferenceActionFilterIT.java | 6 + .../xpack/inference/InferencePlugin.java | 2 +- .../ShardBulkInferenceActionFilter.java | 18 +- .../ShardBulkInferenceActionFilterTests.java | 68 ++++++- ...emanticTextNonDynamicFieldMapperTests.java | 7 + 7 files changed, 265 insertions(+), 9 deletions(-) create mode 100644 docs/changelog/122293.yaml create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java diff --git a/docs/changelog/122293.yaml b/docs/changelog/122293.yaml new file mode 100644 index 0000000000000..31e3da771169e --- /dev/null +++ b/docs/changelog/122293.yaml @@ -0,0 +1,5 @@ +pr: 122293 +summary: Add enterprise license check to inference action for semantic text fields +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java new file mode 100644 index 0000000000000..4fc97662166f0 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Strings; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.junit.Before; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; + +public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase { + public static final String INDEX_NAME = "test-index"; + + private final boolean useLegacyFormat; + + public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() { + return List.of(new Object[] { true }, new Object[] { false }); + } + + @Before + public void setup() throws Exception { + Utils.storeSparseModel(client()); + Utils.storeDenseModel( + client(), + randomIntBetween(1, 100), + // dot product means that we need normalized vectors; it's not worth doing that in this test + randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())), + // TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it + randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values())) + ); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build(); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateInferencePlugin.class); + } + + @Override + public Settings indexSettings() { + var builder = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)) + .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat); + return builder.build(); + } + + public void testLicenseInvalidForInference() { + prepareCreate(INDEX_NAME).setMapping( + String.format( + Locale.ROOT, + """ + { + "properties": { + "sparse_field": { + "type": "semantic_text", + "inference_id": "%s" + }, + "dense_field": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + """, + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + TestDenseInferenceServiceExtension.TestInferenceService.NAME + ) + ).get(); + + BulkRequestBuilder bulkRequest = client().prepareBulk(); + int totalBulkReqs = randomIntBetween(2, 100); + for (int i = 0; i < totalBulkReqs; i++) { + Map source = new HashMap<>(); + source.put("sparse_field", randomSemanticTextInput()); + source.put("dense_field", randomSemanticTextInput()); + + bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); + } + + BulkResponse bulkResponse = bulkRequest.get(); + for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { + assertTrue(bulkItemResponse.isFailed()); + assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class)); + assertThat( + bulkItemResponse.getFailure().getCause().getMessage(), + containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE)) + ); + } + } + + public void testNullSourceSucceeds() { + prepareCreate(INDEX_NAME).setMapping( + String.format( + Locale.ROOT, + """ + { + "properties": { + "sparse_field": { + "type": "semantic_text", + "inference_id": "%s" + }, + "dense_field": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + """, + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + TestDenseInferenceServiceExtension.TestInferenceService.NAME + ) + ).get(); + + BulkRequestBuilder bulkRequest = client().prepareBulk(); + int totalBulkReqs = randomIntBetween(2, 100); + Map source = new HashMap<>(); + source.put("sparse_field", null); + source.put("dense_field", null); + for (int i = 0; i < totalBulkReqs; i++) { + bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); + } + + BulkResponse bulkResponse = bulkRequest.get(); + assertFalse(bulkResponse.hasFailures()); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index b8e594e70abd9..fe3ca7d41ebc8 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESIntegTestCase; @@ -81,6 +82,11 @@ public void setup() throws Exception { ); } + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + @Override protected Collection> nodePlugins() { return Arrays.asList(LocalStateInferencePlugin.class); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 05b8944138a60..d865b241bb4e0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -312,7 +312,7 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); - var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry); + var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState()); shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 5fca096dae1e3..19942595df45b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -38,9 +38,12 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; @@ -58,6 +61,8 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; + /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in @@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final ClusterService clusterService; private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; + private final XPackLicenseState licenseState; private final int batchSize; public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry + ModelRegistry modelRegistry, + XPackLicenseState licenseState ) { - this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE); } public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, + XPackLicenseState licenseState, int batchSize ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; + this.licenseState = licenseState; this.batchSize = batchSize; } @@ -561,6 +570,11 @@ private Map> createFieldInferenceRequests(Bu break; } + if (INFERENCE_API_FEATURE.check(licenseState) == false) { + addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); + break; + } + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); int offsetAdjustment = 0; for (String v : values) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 1fca17f77ad9a..84c3e5cf80b0c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.bulk.BulkItemRequest; @@ -40,6 +41,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -47,8 +49,10 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -113,7 +117,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -136,6 +140,44 @@ public void testFilterNoop() throws Exception { awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testLicenseInvalidForInference() throws InterruptedException { + StaticModel model = StaticModel.createRandomInstance(); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false); + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertThat(bulkShardRequest.items().length, equalTo(1)); + + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertNotNull(failure); + assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class)); + assertThat( + failure.getMessage(), + containsString(org.elasticsearch.core.Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE)) + ); + } finally { + chainExecuted.countDown(); + } + + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "obj.field1", + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + ); + BulkItemRequest[] items = new BulkItemRequest[1]; + items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test")); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = StaticModel.createRandomInstance(); @@ -143,7 +185,8 @@ public void testInferenceNotFound() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - useLegacyFormat + useLegacyFormat, + true ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { @@ -189,7 +232,8 @@ public void testItemFailures() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - useLegacyFormat + useLegacyFormat, + true ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); @@ -255,7 +299,8 @@ public void testExplicitNull() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - useLegacyFormat + useLegacyFormat, + true ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -344,7 +389,13 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30), useLegacyFormat); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + inferenceModelMap, + randomIntBetween(10, 30), + useLegacyFormat, + true + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -379,7 +430,8 @@ private static ShardBulkInferenceActionFilter createFilter( ThreadPool threadPool, Map modelMap, int batchSize, - boolean useLegacyFormat + boolean useLegacyFormat, + boolean isLicenseValidForInference ) { ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { @@ -437,10 +489,14 @@ private static ShardBulkInferenceActionFilter createFilter( InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(isLicenseValidForInference); + return new ShardBulkInferenceActionFilter( createClusterService(useLegacyFormat), inferenceServiceRegistry, modelRegistry, + licenseState, batchSize ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java index 24183b21f73e7..d196efa0d152b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.NonDynamicFieldMapperTests; +import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; @@ -25,6 +27,11 @@ public void setup() throws Exception { Utils.storeSparseModel(client()); } + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + @Override protected Collection> getPlugins() { return List.of(LocalStateInferencePlugin.class);