Skip to content

Commit 2ee3027

Browse files
Add enterprise license check to inference action for semantic text fields (#122293) (#123065)
* Add enterprise license check to inference action for semantic text fields * Update docs/changelog/122293.yaml * Set license to trial in ShardBulkInferenceActionFilterIT * Move license check to only block semantic_text fields that require inference call * Cleaning up tests * Add parameterization on useLegacyFormat back in ShardBulkInferenceActionFilterBasicLicenseIT --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 610e4ac commit 2ee3027

File tree

7 files changed

+265
-9
lines changed

7 files changed

+265
-9
lines changed

docs/changelog/122293.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 122293
2+
summary: Add enterprise license check to inference action for semantic text fields
3+
area: Machine Learning
4+
type: bug
5+
issues: []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action.filter;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
11+
12+
import org.elasticsearch.ElasticsearchSecurityException;
13+
import org.elasticsearch.action.bulk.BulkItemResponse;
14+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
15+
import org.elasticsearch.action.bulk.BulkResponse;
16+
import org.elasticsearch.action.index.IndexRequestBuilder;
17+
import org.elasticsearch.cluster.metadata.IndexMetadata;
18+
import org.elasticsearch.common.settings.Settings;
19+
import org.elasticsearch.core.Strings;
20+
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
21+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
22+
import org.elasticsearch.inference.SimilarityMeasure;
23+
import org.elasticsearch.license.LicenseSettings;
24+
import org.elasticsearch.plugins.Plugin;
25+
import org.elasticsearch.test.ESIntegTestCase;
26+
import org.elasticsearch.xpack.core.XPackField;
27+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
28+
import org.elasticsearch.xpack.inference.Utils;
29+
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
30+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
31+
import org.junit.Before;
32+
33+
import java.util.Arrays;
34+
import java.util.Collection;
35+
import java.util.HashMap;
36+
import java.util.List;
37+
import java.util.Locale;
38+
import java.util.Map;
39+
40+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
41+
import static org.hamcrest.Matchers.containsString;
42+
import static org.hamcrest.Matchers.instanceOf;
43+
44+
public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
45+
public static final String INDEX_NAME = "test-index";
46+
47+
private final boolean useLegacyFormat;
48+
49+
public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat) {
50+
this.useLegacyFormat = useLegacyFormat;
51+
}
52+
53+
@ParametersFactory
54+
public static Iterable<Object[]> parameters() {
55+
return List.of(new Object[] { true }, new Object[] { false });
56+
}
57+
58+
@Before
59+
public void setup() throws Exception {
60+
Utils.storeSparseModel(client());
61+
Utils.storeDenseModel(
62+
client(),
63+
randomIntBetween(1, 100),
64+
// dot product means that we need normalized vectors; it's not worth doing that in this test
65+
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
66+
// TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
67+
randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
68+
);
69+
}
70+
71+
@Override
72+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
73+
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build();
74+
}
75+
76+
@Override
77+
protected Collection<Class<? extends Plugin>> nodePlugins() {
78+
return Arrays.asList(LocalStateInferencePlugin.class);
79+
}
80+
81+
@Override
82+
public Settings indexSettings() {
83+
var builder = Settings.builder()
84+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
85+
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
86+
return builder.build();
87+
}
88+
89+
public void testLicenseInvalidForInference() {
90+
prepareCreate(INDEX_NAME).setMapping(
91+
String.format(
92+
Locale.ROOT,
93+
"""
94+
{
95+
"properties": {
96+
"sparse_field": {
97+
"type": "semantic_text",
98+
"inference_id": "%s"
99+
},
100+
"dense_field": {
101+
"type": "semantic_text",
102+
"inference_id": "%s"
103+
}
104+
}
105+
}
106+
""",
107+
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
108+
TestDenseInferenceServiceExtension.TestInferenceService.NAME
109+
)
110+
).get();
111+
112+
BulkRequestBuilder bulkRequest = client().prepareBulk();
113+
int totalBulkReqs = randomIntBetween(2, 100);
114+
for (int i = 0; i < totalBulkReqs; i++) {
115+
Map<String, Object> source = new HashMap<>();
116+
source.put("sparse_field", randomSemanticTextInput());
117+
source.put("dense_field", randomSemanticTextInput());
118+
119+
bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
120+
}
121+
122+
BulkResponse bulkResponse = bulkRequest.get();
123+
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
124+
assertTrue(bulkItemResponse.isFailed());
125+
assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
126+
assertThat(
127+
bulkItemResponse.getFailure().getCause().getMessage(),
128+
containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
129+
);
130+
}
131+
}
132+
133+
public void testNullSourceSucceeds() {
134+
prepareCreate(INDEX_NAME).setMapping(
135+
String.format(
136+
Locale.ROOT,
137+
"""
138+
{
139+
"properties": {
140+
"sparse_field": {
141+
"type": "semantic_text",
142+
"inference_id": "%s"
143+
},
144+
"dense_field": {
145+
"type": "semantic_text",
146+
"inference_id": "%s"
147+
}
148+
}
149+
}
150+
""",
151+
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
152+
TestDenseInferenceServiceExtension.TestInferenceService.NAME
153+
)
154+
).get();
155+
156+
BulkRequestBuilder bulkRequest = client().prepareBulk();
157+
int totalBulkReqs = randomIntBetween(2, 100);
158+
Map<String, Object> source = new HashMap<>();
159+
source.put("sparse_field", null);
160+
source.put("dense_field", null);
161+
for (int i = 0; i < totalBulkReqs; i++) {
162+
bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
163+
}
164+
165+
BulkResponse bulkResponse = bulkRequest.get();
166+
assertFalse(bulkResponse.hasFailures());
167+
}
168+
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.index.mapper.SourceFieldMapper;
2626
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2727
import org.elasticsearch.inference.SimilarityMeasure;
28+
import org.elasticsearch.license.LicenseSettings;
2829
import org.elasticsearch.plugins.Plugin;
2930
import org.elasticsearch.search.builder.SearchSourceBuilder;
3031
import org.elasticsearch.test.ESIntegTestCase;
@@ -81,6 +82,11 @@ public void setup() throws Exception {
8182
);
8283
}
8384

85+
@Override
86+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
87+
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
88+
}
89+
8490
@Override
8591
protected Collection<Class<? extends Plugin>> nodePlugins() {
8692
return Arrays.asList(LocalStateInferencePlugin.class);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ public Collection<?> createComponents(PluginServices services) {
312312
}
313313
inferenceServiceRegistry.set(serviceRegistry);
314314

315-
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
315+
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
316316
shardBulkInferenceActionFilter.set(actionFilter);
317317

318318
var meterRegistry = services.telemetryProvider().getMeterRegistry();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,12 @@
3838
import org.elasticsearch.inference.MinimalServiceSettings;
3939
import org.elasticsearch.inference.Model;
4040
import org.elasticsearch.inference.UnparsedModel;
41+
import org.elasticsearch.license.LicenseUtils;
42+
import org.elasticsearch.license.XPackLicenseState;
4143
import org.elasticsearch.rest.RestStatus;
4244
import org.elasticsearch.tasks.Task;
4345
import org.elasticsearch.xcontent.XContent;
46+
import org.elasticsearch.xpack.core.XPackField;
4447
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
4548
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
4649
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
@@ -58,6 +61,8 @@
5861
import java.util.Map;
5962
import java.util.stream.Collectors;
6063

64+
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
65+
6166
/**
6267
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
6368
* as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in
@@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
7681
private final ClusterService clusterService;
7782
private final InferenceServiceRegistry inferenceServiceRegistry;
7883
private final ModelRegistry modelRegistry;
84+
private final XPackLicenseState licenseState;
7985
private final int batchSize;
8086

8187
public ShardBulkInferenceActionFilter(
8288
ClusterService clusterService,
8389
InferenceServiceRegistry inferenceServiceRegistry,
84-
ModelRegistry modelRegistry
90+
ModelRegistry modelRegistry,
91+
XPackLicenseState licenseState
8592
) {
86-
this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE);
93+
this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
8794
}
8895

8996
public ShardBulkInferenceActionFilter(
9097
ClusterService clusterService,
9198
InferenceServiceRegistry inferenceServiceRegistry,
9299
ModelRegistry modelRegistry,
100+
XPackLicenseState licenseState,
93101
int batchSize
94102
) {
95103
this.clusterService = clusterService;
96104
this.inferenceServiceRegistry = inferenceServiceRegistry;
97105
this.modelRegistry = modelRegistry;
106+
this.licenseState = licenseState;
98107
this.batchSize = batchSize;
99108
}
100109

@@ -561,6 +570,11 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
561570
break;
562571
}
563572

573+
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
574+
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
575+
break;
576+
}
577+
564578
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
565579
int offsetAdjustment = 0;
566580
for (String v : values) {

0 commit comments

Comments
 (0)