Skip to content

Commit 7604667

Browse files
Move license check to only block semantic_text fields that require inference call
1 parent 5ce7e42 commit 7604667

File tree

3 files changed

+162
-16
lines changed

3 files changed

+162
-16
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action.filter;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
11+
12+
import org.elasticsearch.ElasticsearchSecurityException;
13+
import org.elasticsearch.action.bulk.BulkItemResponse;
14+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
15+
import org.elasticsearch.action.bulk.BulkResponse;
16+
import org.elasticsearch.action.index.IndexRequestBuilder;
17+
import org.elasticsearch.cluster.metadata.IndexMetadata;
18+
import org.elasticsearch.common.settings.Settings;
19+
import org.elasticsearch.index.IndexSettings;
20+
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
21+
import org.elasticsearch.index.mapper.SourceFieldMapper;
22+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
23+
import org.elasticsearch.inference.SimilarityMeasure;
24+
import org.elasticsearch.license.LicenseSettings;
25+
import org.elasticsearch.plugins.Plugin;
26+
import org.elasticsearch.test.ESIntegTestCase;
27+
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
28+
import org.elasticsearch.xpack.inference.Utils;
29+
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
30+
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
31+
import org.junit.Before;
32+
33+
import java.util.Arrays;
34+
import java.util.Collection;
35+
import java.util.HashMap;
36+
import java.util.List;
37+
import java.util.Locale;
38+
import java.util.Map;
39+
40+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
41+
import static org.hamcrest.Matchers.instanceOf;
42+
43+
public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
44+
public static final String INDEX_NAME = "test-index";
45+
46+
private final boolean useLegacyFormat;
47+
private final boolean useSyntheticSource;
48+
49+
public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat, boolean useSyntheticSource) {
50+
this.useLegacyFormat = useLegacyFormat;
51+
this.useSyntheticSource = useSyntheticSource;
52+
}
53+
54+
@ParametersFactory
55+
public static Iterable<Object[]> parameters() throws Exception {
56+
return List.of(
57+
new Object[] { true, false },
58+
new Object[] { true, true },
59+
new Object[] { false, false },
60+
new Object[] { false, true }
61+
);
62+
}
63+
64+
@Before
65+
public void setup() throws Exception {
66+
Utils.storeSparseModel(client());
67+
Utils.storeDenseModel(
68+
client(),
69+
randomIntBetween(1, 100),
70+
// dot product means that we need normalized vectors; it's not worth doing that in this test
71+
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
72+
// TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
73+
randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
74+
);
75+
}
76+
77+
@Override
78+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
79+
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build();
80+
}
81+
82+
@Override
83+
protected Collection<Class<? extends Plugin>> nodePlugins() {
84+
return Arrays.asList(LocalStateInferencePlugin.class);
85+
}
86+
87+
@Override
88+
public Settings indexSettings() {
89+
var builder = Settings.builder()
90+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
91+
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
92+
if (useSyntheticSource) {
93+
builder.put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true);
94+
builder.put(IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name());
95+
}
96+
return builder.build();
97+
}
98+
99+
public void testLicenseInvalidForInference() {
100+
prepareCreate(INDEX_NAME).setMapping(
101+
String.format(
102+
Locale.ROOT,
103+
"""
104+
{
105+
"properties": {
106+
"sparse_field": {
107+
"type": "semantic_text",
108+
"inference_id": "%s"
109+
},
110+
"dense_field": {
111+
"type": "semantic_text",
112+
"inference_id": "%s"
113+
}
114+
}
115+
}
116+
""",
117+
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
118+
TestDenseInferenceServiceExtension.TestInferenceService.NAME
119+
)
120+
).get();
121+
122+
BulkRequestBuilder bulkRequest = client().prepareBulk();
123+
int totalBulkReqs = randomIntBetween(2, 100);
124+
for (int i = 0; i < totalBulkReqs; i++) {
125+
Map<String, Object> source = new HashMap<>();
126+
source.put("sparse_field", rarely() ? null : randomSemanticTextInput());
127+
source.put("dense_field", rarely() ? null : randomSemanticTextInput());
128+
129+
bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
130+
}
131+
132+
BulkResponse bulkResponse = bulkRequest.get();
133+
for (BulkItemResponse itemResponse : bulkResponse) {
134+
assertTrue(itemResponse.isFailed());
135+
assertThat(itemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
136+
}
137+
}
138+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,6 @@ private AsyncBulkShardInferenceAction(
216216

217217
@Override
218218
public void run() {
219-
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
220-
throw LicenseUtils.newComplianceException(XPackField.INFERENCE);
221-
}
222-
223219
Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
224220
Runnable onInferenceCompletion = () -> {
225221
try {
@@ -574,6 +570,11 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
574570
break;
575571
}
576572

573+
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
574+
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
575+
break;
576+
}
577+
577578
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
578579
int offsetAdjustment = 0;
579580
for (String v : values) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,29 @@ public void testFilterNoop() throws Exception {
141141

142142
@SuppressWarnings({ "unchecked", "rawtypes" })
143143
public void testLicenseInvalidForInference() {
144+
StaticModel model = StaticModel.createRandomInstance();
144145
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
145-
ActionFilterChain actionFilterChain = mock(ActionFilterChain.class);
146+
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
147+
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
148+
assertThat(bulkShardRequest.items().length, equalTo(1));
149+
150+
BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
151+
assertNotNull(failure);
152+
assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class));
153+
};
146154
ActionListener actionListener = mock(ActionListener.class);
147155
Task task = mock(Task.class);
148-
BulkShardRequest request = new BulkShardRequest(
149-
new ShardId("test", "test", 0),
150-
WriteRequest.RefreshPolicy.NONE,
151-
new BulkItemRequest[0]
152-
);
153-
request.setInferenceFieldMap(
154-
Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false)))
155-
);
156-
assertThrows(
157-
ElasticsearchSecurityException.class,
158-
() -> filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain)
156+
157+
Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
158+
"obj.field1",
159+
new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" })
159160
);
161+
BulkItemRequest[] items = new BulkItemRequest[1];
162+
items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test"));
163+
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
164+
request.setInferenceFieldMap(inferenceFieldMap);
165+
166+
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
160167
}
161168

162169
@SuppressWarnings({ "unchecked", "rawtypes" })

0 commit comments

Comments
 (0)