docs/changelog/122293.yaml (new file, 5 additions & 0 deletions)
@@ -0,0 +1,5 @@
pr: 122293
summary: Add enterprise license check to inference action for semantic text fields
area: Machine Learning
type: bug
issues: []
ShardBulkInferenceActionFilterBasicLicenseIT.java (new file)
@@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.action.filter;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.junit.Before;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.instanceOf;

public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
public static final String INDEX_NAME = "test-index";

private final boolean useLegacyFormat;
private final boolean useSyntheticSource;

public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat, boolean useSyntheticSource) {
this.useLegacyFormat = useLegacyFormat;
this.useSyntheticSource = useSyntheticSource;
}

@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
return List.of(
new Object[] { true, false },
new Object[] { true, true },
new Object[] { false, false },
new Object[] { false, true }
);
}
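The four tuples above are exactly the 2×2 matrix of useLegacyFormat × useSyntheticSource. An equivalent nested-loop construction is shown below as a sketch only (it assumes a java.util.ArrayList import; the explicit List.of above is what ships):

// Sketch: the same 2x2 parameter matrix, built programmatically.
// Equivalent to the explicit List.of(...) above.
List<Object[]> params = new ArrayList<>();
for (boolean legacyFormat : new boolean[] { true, false }) {
    for (boolean syntheticSource : new boolean[] { false, true }) {
        params.add(new Object[] { legacyFormat, syntheticSource });
    }
}
return params;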

@Before
public void setup() throws Exception {
Utils.storeSparseModel(client());
Utils.storeDenseModel(
client(),
randomIntBetween(1, 100),
// dot product means that we need normalized vectors; it's not worth doing that in this test
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
// TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
}

@Override
public Settings indexSettings() {
var builder = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
if (useSyntheticSource) {
builder.put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true);
builder.put(IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name());
}
return builder.build();
}
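When useSyntheticSource is set, the index both stores _source synthetically and uses it for recovery. A sketch of the assumed effective settings (key names derived from the constants referenced above, not spelled out in this diff):

// Assumed effective index settings for the synthetic-source variant (sketch):
// index.recovery.use_synthetic_source: true
// index.mapping.source.mode: SYNTHETIC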

public void testLicenseInvalidForInference() {
prepareCreate(INDEX_NAME).setMapping(
String.format(
Locale.ROOT,
"""
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "%s"
},
"dense_field": {
"type": "semantic_text",
"inference_id": "%s"
}
}
}
""",
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
TestDenseInferenceServiceExtension.TestInferenceService.NAME
)
).get();

BulkRequestBuilder bulkRequest = client().prepareBulk();
int totalBulkItems = randomIntBetween(2, 100);
for (int i = 0; i < totalBulkItems; i++) {
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", rarely() ? null : randomSemanticTextInput());
source.put("dense_field", rarely() ? null : randomSemanticTextInput());

bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
}

BulkResponse bulkResponse = bulkRequest.get();
for (BulkItemResponse itemResponse : bulkResponse) {
assertTrue(itemResponse.isFailed());
assertThat(itemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
}
}
}
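Because each per-item failure is produced by LicenseUtils.newComplianceException (see ShardBulkInferenceActionFilter below), a stricter variant of the assertion loop above could also pin the HTTP status. This is a sketch only; the FORBIDDEN status is an assumption about how compliance exceptions are surfaced, and it would need an org.elasticsearch.rest.RestStatus import:

// Sketch: stricter per-item assertions on the license failure.
for (BulkItemResponse itemResponse : bulkResponse) {
    assertTrue(itemResponse.isFailed());
    Throwable cause = itemResponse.getFailure().getCause();
    assertThat(cause, instanceOf(ElasticsearchSecurityException.class));
    assertEquals(RestStatus.FORBIDDEN, ((ElasticsearchSecurityException) cause).status());
}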
ShardBulkInferenceActionFilterIT.java
@@ -25,6 +25,7 @@
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
@@ -80,6 +81,11 @@ public void setup() throws Exception {
);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
InferencePlugin.java
@@ -310,7 +310,7 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
ShardBulkInferenceActionFilter.java
@@ -38,9 +38,12 @@
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
@@ -58,6 +61,8 @@
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;

/**
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
* as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in
@@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final int batchSize;

public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry
ModelRegistry modelRegistry,
XPackLicenseState licenseState
) {
this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE);
this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
}

public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState,
int batchSize
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.batchSize = batchSize;
}

@@ -561,6 +570,11 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
break;
}

if (INFERENCE_API_FEATURE.check(licenseState) == false) {
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
break;
}

List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
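The gate above relies on INFERENCE_API_FEATURE, statically imported from InferencePlugin; only the call site is visible in this diff. As a hedged sketch, an X-Pack licensed feature of this kind is typically declared via LicensedFeature.momentary; the family/name strings and the enterprise level below are assumptions, not taken from this change:

// Hedged sketch of the declaration behind INFERENCE_API_FEATURE.
public static final LicensedFeature.Momentary INFERENCE_API_FEATURE =
    LicensedFeature.momentary("inference", "api", License.OperationMode.ENTERPRISE);

// check(...) reports whether the feature is allowed and records its usage;
// the filter maps a false result to LicenseUtils.newComplianceException.
boolean allowed = INFERENCE_API_FEATURE.check(licenseState);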
ShardBulkInferenceActionFilterTests.java
@@ -9,6 +9,7 @@

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.BulkItemRequest;
@@ -40,6 +41,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
@@ -49,6 +51,7 @@
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -113,7 +116,7 @@ public void tearDownThreadPool() throws Exception {

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testFilterNoop() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat);
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -136,14 +139,42 @@ public void testFilterNoop() throws Exception {
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testLicenseInvalidForInference() throws Exception {
    StaticModel model = StaticModel.createRandomInstance();
    ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
    CountDownLatch chainExecuted = new CountDownLatch(1);
    ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
        try {
            BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
            assertThat(bulkShardRequest.items().length, equalTo(1));

            BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
            assertNotNull(failure);
            assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class));
        } finally {
            // Count down even on assertion failure so the test cannot hang.
            chainExecuted.countDown();
        }
    };
    ActionListener actionListener = mock(ActionListener.class);
    Task task = mock(Task.class);

    Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
        "obj.field1",
        new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" })
    );
    BulkItemRequest[] items = new BulkItemRequest[1];
    items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test"));
    BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
    request.setInferenceFieldMap(inferenceFieldMap);

    filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
    // Wait for the chain to run, as the sibling tests do; otherwise the
    // assertions above may not execute before the test method returns.
    awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
}
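As a usage note, the boolean threaded through createFilter (see the helper further down) selects between the two branches exercised by this file; a minimal sketch:

// Valid license: the filter passes requests through untouched (testFilterNoop).
ShardBulkInferenceActionFilter allowed =
    createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);

// Invalid license: items with semantic text fields fail with an
// ElasticsearchSecurityException before any inference service is invoked.
ShardBulkInferenceActionFilter denied =
    createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);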

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testInferenceNotFound() throws Exception {
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat
useLegacyFormat,
true
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
@@ -189,7 +220,8 @@ public void testItemFailures() throws Exception {
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat
useLegacyFormat,
true
);
model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
@@ -255,7 +287,8 @@ public void testExplicitNull() throws Exception {
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat
useLegacyFormat,
true
);

CountDownLatch chainExecuted = new CountDownLatch(1);
@@ -344,7 +377,13 @@ public void testManyRandomDocs() throws Exception {
modifiedRequests[id] = res[1];
}

ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30), useLegacyFormat);
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
inferenceModelMap,
randomIntBetween(10, 30),
useLegacyFormat,
true
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -379,7 +418,8 @@ private static ShardBulkInferenceActionFilter createFilter(
ThreadPool threadPool,
Map<String, StaticModel> modelMap,
int batchSize,
boolean useLegacyFormat
boolean useLegacyFormat,
boolean isLicenseValidForInference
) {
ModelRegistry modelRegistry = mock(ModelRegistry.class);
Answer<?> unparsedModelAnswer = invocationOnMock -> {
@@ -437,10 +477,14 @@ private static ShardBulkInferenceActionFilter createFilter(
InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class);
when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService));

MockLicenseState licenseState = MockLicenseState.createMock();
when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(isLicenseValidForInference);

return new ShardBulkInferenceActionFilter(
createClusterService(useLegacyFormat),
inferenceServiceRegistry,
modelRegistry,
licenseState,
batchSize
);
}
SemanticTextNonDynamicFieldMapperTests.java
@@ -7,7 +7,9 @@

package org.elasticsearch.xpack.inference.mapper;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.NonDynamicFieldMapperTests;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
@@ -25,6 +27,11 @@ public void setup() throws Exception {
Utils.storeSparseModel(client());
}

@Override
protected Settings nodeSettings() {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
}

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);