From 31af555c4a582eb5fc15ad09c7a6ada5791d0ed2 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Tue, 11 Feb 2025 13:28:44 -0500 Subject: [PATCH 1/6] Add enterprise license check to inference action for semantic text fields --- .../xpack/inference/InferencePlugin.java | 2 +- .../ShardBulkInferenceActionFilter.java | 17 ++++++- .../ShardBulkInferenceActionFilterTests.java | 49 ++++++++++++++++--- ...emanticTextNonDynamicFieldMapperTests.java | 7 +++ 4 files changed, 66 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index e3604351c1937..db57a4204631d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -310,7 +310,7 @@ public Collection createComponents(PluginServices services) { } inferenceServiceRegistry.set(serviceRegistry); - var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry); + var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState()); shardBulkInferenceActionFilter.set(actionFilter); var meterRegistry = services.telemetryProvider().getMeterRegistry(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 3933260664b7c..79b55fccd33a0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -38,9 +38,12 @@ import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; @@ -58,6 +61,8 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE; + /** * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified * as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in @@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final ClusterService clusterService; private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; + private final XPackLicenseState licenseState; private final int batchSize; public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, - ModelRegistry modelRegistry + ModelRegistry modelRegistry, + XPackLicenseState licenseState ) { - this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE); + this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE); } public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, + XPackLicenseState licenseState, int batchSize ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; + this.licenseState = licenseState; this.batchSize = batchSize; } @@ -207,6 +216,10 @@ private AsyncBulkShardInferenceAction( @Override public void run() { + if (INFERENCE_API_FEATURE.check(licenseState) == false) { + throw LicenseUtils.newComplianceException(XPackField.INFERENCE); + } + Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); Runnable onInferenceCompletion = () -> { try { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 1fca17f77ad9a..74b6e79eb7106 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -9,6 +9,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.bulk.BulkItemRequest; @@ -40,6 +41,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.license.MockLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -49,6 +51,7 @@ import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.model.TestModel; import org.elasticsearch.xpack.inference.registry.ModelRegistry; @@ -113,7 +116,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -136,6 +139,26 @@ public void testFilterNoop() throws Exception { awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testLicenseInvalidForInference() { + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false); + ActionFilterChain actionFilterChain = mock(ActionFilterChain.class); + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + BulkShardRequest request = new BulkShardRequest( + new ShardId("test", "test", 0), + WriteRequest.RefreshPolicy.NONE, + new BulkItemRequest[0] + ); + request.setInferenceFieldMap( + Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false))) + ); + assertThrows( + ElasticsearchSecurityException.class, + () -> filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain) + ); + } + @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { StaticModel model = StaticModel.createRandomInstance(); @@ -143,7 +166,8 @@ public void testInferenceNotFound() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - useLegacyFormat + useLegacyFormat, + true ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { @@ -189,7 +213,8 @@ public void testItemFailures() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - useLegacyFormat + useLegacyFormat, + true ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); @@ -255,7 +280,8 @@ public void testExplicitNull() throws Exception { threadPool, Map.of(model.getInferenceEntityId(), model), randomIntBetween(1, 10), - useLegacyFormat + useLegacyFormat, + true ); CountDownLatch chainExecuted = new CountDownLatch(1); @@ -344,7 +370,13 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30), useLegacyFormat); + ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + inferenceModelMap, + randomIntBetween(10, 30), + useLegacyFormat, + true + ); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -379,7 +411,8 @@ private static ShardBulkInferenceActionFilter createFilter( ThreadPool threadPool, Map modelMap, int batchSize, - boolean useLegacyFormat + boolean useLegacyFormat, + boolean isLicenseValidForInference ) { ModelRegistry modelRegistry = mock(ModelRegistry.class); Answer unparsedModelAnswer = invocationOnMock -> { @@ -437,10 +470,14 @@ private static ShardBulkInferenceActionFilter createFilter( InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class); when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService)); + MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(isLicenseValidForInference); + return new ShardBulkInferenceActionFilter( createClusterService(useLegacyFormat), inferenceServiceRegistry, modelRegistry, + licenseState, batchSize ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java index 24183b21f73e7..d196efa0d152b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextNonDynamicFieldMapperTests.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.mapper; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.mapper.NonDynamicFieldMapperTests; +import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; @@ -25,6 +27,11 @@ public void setup() throws Exception { Utils.storeSparseModel(client()); } + @Override + protected Settings nodeSettings() { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + @Override protected Collection> getPlugins() { return List.of(LocalStateInferencePlugin.class); From b20723121957e2a72b46d347dd3abbb62de72f4e Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Tue, 11 Feb 2025 14:37:51 -0500 Subject: [PATCH 2/6] Update docs/changelog/122293.yaml --- docs/changelog/122293.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/122293.yaml diff --git a/docs/changelog/122293.yaml b/docs/changelog/122293.yaml new file mode 100644 index 0000000000000..31e3da771169e --- /dev/null +++ b/docs/changelog/122293.yaml @@ -0,0 +1,5 @@ +pr: 122293 +summary: Add enterprise license check to inference action for semantic text fields +area: Machine Learning +type: bug +issues: [] From 5ce7e42b4b8d10a9f6b089d21bee9166211a610c Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Tue, 11 Feb 2025 16:36:47 -0500 Subject: [PATCH 3/6] Set license to trial in ShardBulkInferenceActionFilterIT --- .../action/filter/ShardBulkInferenceActionFilterIT.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java index 303f957c7ab20..63910908893c2 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESIntegTestCase; @@ -80,6 +81,11 @@ public void setup() throws Exception { ); } + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build(); + } + @Override protected Collection> nodePlugins() { return Arrays.asList(LocalStateInferencePlugin.class); From 760466763b9094720011343db5087f621035a6aa Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Wed, 12 Feb 2025 14:36:17 -0500 Subject: [PATCH 4/6] Move license check to only block semantic_text fields that require inference call --- ...lkInferenceActionFilterBasicLicenseIT.java | 138 ++++++++++++++++++ .../ShardBulkInferenceActionFilter.java | 9 +- .../ShardBulkInferenceActionFilterTests.java | 31 ++-- 3 files changed, 162 insertions(+), 16 deletions(-) create mode 100644 x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java new file mode 100644 index 0000000000000..7b09eb602256b --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action.filter; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.index.mapper.SourceFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; +import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension; +import org.junit.Before; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; +import static org.hamcrest.Matchers.instanceOf; + +public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase { + public static final String INDEX_NAME = "test-index"; + + private final boolean useLegacyFormat; + private final boolean useSyntheticSource; + + public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat, boolean useSyntheticSource) { + this.useLegacyFormat = useLegacyFormat; + this.useSyntheticSource = useSyntheticSource; + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return List.of( + new Object[] { true, false }, + new Object[] { true, true }, + new Object[] { false, false }, + new Object[] { false, true } + ); + } + + @Before + public void setup() throws Exception { + Utils.storeSparseModel(client()); + Utils.storeDenseModel( + client(), + randomIntBetween(1, 100), + // dot product means that we need normalized vectors; it's not worth doing that in this test + randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())), + // TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it + randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values())) + ); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build(); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateInferencePlugin.class); + } + + @Override + public Settings indexSettings() { + var builder = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)) + .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat); + if (useSyntheticSource) { + builder.put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true); + builder.put(IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name()); + } + return builder.build(); + } + + public void testLicenseInvalidForInference() { + prepareCreate(INDEX_NAME).setMapping( + String.format( + Locale.ROOT, + """ + { + "properties": { + "sparse_field": { + "type": "semantic_text", + "inference_id": "%s" + }, + "dense_field": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + """, + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + TestDenseInferenceServiceExtension.TestInferenceService.NAME + ) + ).get(); + + BulkRequestBuilder bulkRequest = client().prepareBulk(); + int totalBulkReqs = randomIntBetween(2, 100); + for (int i = 0; i < totalBulkReqs; i++) { + Map source = new HashMap<>(); + source.put("sparse_field", rarely() ? null : randomSemanticTextInput()); + source.put("dense_field", rarely() ? null : randomSemanticTextInput()); + + bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); + } + + BulkResponse bulkResponse = bulkRequest.get(); + for (BulkItemResponse itemResponse : bulkResponse) { + assertTrue(itemResponse.isFailed()); + assertThat(itemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 79b55fccd33a0..d9d9727b3c79a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -216,10 +216,6 @@ private AsyncBulkShardInferenceAction( @Override public void run() { - if (INFERENCE_API_FEATURE.check(licenseState) == false) { - throw LicenseUtils.newComplianceException(XPackField.INFERENCE); - } - Map> inferenceRequests = createFieldInferenceRequests(bulkShardRequest); Runnable onInferenceCompletion = () -> { try { @@ -574,6 +570,11 @@ private Map> createFieldInferenceRequests(Bu break; } + if (INFERENCE_API_FEATURE.check(licenseState) == false) { + addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE)); + break; + } + List fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); int offsetAdjustment = 0; for (String v : values) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 74b6e79eb7106..43886803997f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -141,22 +141,29 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() { + StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false); - ActionFilterChain actionFilterChain = mock(ActionFilterChain.class); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertThat(bulkShardRequest.items().length, equalTo(1)); + + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertNotNull(failure); + assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class)); + }; ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); - BulkShardRequest request = new BulkShardRequest( - new ShardId("test", "test", 0), - WriteRequest.RefreshPolicy.NONE, - new BulkItemRequest[0] - ); - request.setInferenceFieldMap( - Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false))) - ); - assertThrows( - ElasticsearchSecurityException.class, - () -> filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain) + + Map inferenceFieldMap = Map.of( + "obj.field1", + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) ); + BulkItemRequest[] items = new BulkItemRequest[1]; + items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test")); + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); } @SuppressWarnings({ "unchecked", "rawtypes" }) From af3f1e951c6c1f760ef3f229808ce97450261373 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Thu, 13 Feb 2025 10:24:38 -0500 Subject: [PATCH 5/6] Cleaning up tests --- ...lkInferenceActionFilterBasicLicenseIT.java | 85 +++++++++++-------- .../ShardBulkInferenceActionFilterTests.java | 24 ++++-- 2 files changed, 67 insertions(+), 42 deletions(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java index 7b09eb602256b..c66cae2face51 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.inference.action.filter; -import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; - import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -16,14 +14,13 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.index.IndexSettings; -import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; -import org.elasticsearch.index.mapper.SourceFieldMapper; +import org.elasticsearch.core.Strings; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.license.LicenseSettings; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension; @@ -33,34 +30,16 @@ import java.util.Arrays; import java.util.Collection; import java.util.HashMap; -import java.util.List; import java.util.Locale; import java.util.Map; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase { public static final String INDEX_NAME = "test-index"; - private final boolean useLegacyFormat; - private final boolean useSyntheticSource; - - public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat, boolean useSyntheticSource) { - this.useLegacyFormat = useLegacyFormat; - this.useSyntheticSource = useSyntheticSource; - } - - @ParametersFactory - public static Iterable parameters() throws Exception { - return List.of( - new Object[] { true, false }, - new Object[] { true, true }, - new Object[] { false, false }, - new Object[] { false, true } - ); - } - @Before public void setup() throws Exception { Utils.storeSparseModel(client()); @@ -86,13 +65,7 @@ protected Collection> nodePlugins() { @Override public Settings indexSettings() { - var builder = Settings.builder() - .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)) - .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat); - if (useSyntheticSource) { - builder.put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true); - builder.put(IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name()); - } + var builder = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)); return builder.build(); } @@ -123,16 +96,56 @@ public void testLicenseInvalidForInference() { int totalBulkReqs = randomIntBetween(2, 100); for (int i = 0; i < totalBulkReqs; i++) { Map source = new HashMap<>(); - source.put("sparse_field", rarely() ? null : randomSemanticTextInput()); - source.put("dense_field", rarely() ? null : randomSemanticTextInput()); + source.put("sparse_field", randomSemanticTextInput()); + source.put("dense_field", randomSemanticTextInput()); bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); } BulkResponse bulkResponse = bulkRequest.get(); - for (BulkItemResponse itemResponse : bulkResponse) { - assertTrue(itemResponse.isFailed()); - assertThat(itemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class)); + for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { + assertTrue(bulkItemResponse.isFailed()); + assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class)); + assertThat( + bulkItemResponse.getFailure().getCause().getMessage(), + containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE)) + ); } } + + public void testNullSourceSucceeds() { + prepareCreate(INDEX_NAME).setMapping( + String.format( + Locale.ROOT, + """ + { + "properties": { + "sparse_field": { + "type": "semantic_text", + "inference_id": "%s" + }, + "dense_field": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + """, + TestSparseInferenceServiceExtension.TestInferenceService.NAME, + TestDenseInferenceServiceExtension.TestInferenceService.NAME + ) + ).get(); + + BulkRequestBuilder bulkRequest = client().prepareBulk(); + int totalBulkReqs = randomIntBetween(2, 100); + Map source = new HashMap<>(); + source.put("sparse_field", null); + source.put("dense_field", null); + for (int i = 0; i < totalBulkReqs; i++) { + bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source)); + } + + BulkResponse bulkResponse = bulkRequest.get(); + assertFalse(bulkResponse.hasFailures()); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 43886803997f1..84c3e5cf80b0c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -49,6 +49,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferencePlugin; @@ -140,16 +141,26 @@ public void testFilterNoop() throws Exception { } @SuppressWarnings({ "unchecked", "rawtypes" }) - public void testLicenseInvalidForInference() { + public void testLicenseInvalidForInference() throws InterruptedException { StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false); + CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - BulkShardRequest bulkShardRequest = (BulkShardRequest) request; - assertThat(bulkShardRequest.items().length, equalTo(1)); + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertThat(bulkShardRequest.items().length, equalTo(1)); + + BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); + assertNotNull(failure); + assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class)); + assertThat( + failure.getMessage(), + containsString(org.elasticsearch.core.Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE)) + ); + } finally { + chainExecuted.countDown(); + } - BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure(); - assertNotNull(failure); - assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class)); }; ActionListener actionListener = mock(ActionListener.class); Task task = mock(Task.class); @@ -164,6 +175,7 @@ public void testLicenseInvalidForInference() { request.setInferenceFieldMap(inferenceFieldMap); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } @SuppressWarnings({ "unchecked", "rawtypes" }) From 0473574ec7db90762bfc91d69df3dfd055cbef34 Mon Sep 17 00:00:00 2001 From: dan-rubinstein Date: Thu, 13 Feb 2025 14:20:15 -0500 Subject: [PATCH 6/6] Add parameterization on useLegacyFormat back in ShardBulkInferenceActionFilterBasicLicenseIT --- ...lkInferenceActionFilterBasicLicenseIT.java | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java index c66cae2face51..4fc97662166f0 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.action.filter; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -15,6 +17,7 @@ import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Strings; +import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.license.LicenseSettings; @@ -30,6 +33,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Locale; import java.util.Map; @@ -40,6 +44,17 @@ public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase { public static final String INDEX_NAME = "test-index"; + private final boolean useLegacyFormat; + + public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat) { + this.useLegacyFormat = useLegacyFormat; + } + + @ParametersFactory + public static Iterable parameters() { + return List.of(new Object[] { true }, new Object[] { false }); + } + @Before public void setup() throws Exception { Utils.storeSparseModel(client()); @@ -65,7 +80,9 @@ protected Collection> nodePlugins() { @Override public Settings indexSettings() { - var builder = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)); + var builder = Settings.builder() + .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10)) + .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat); return builder.build(); }