Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/122293.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122293
summary: Add enterprise license check to inference action for semantic text fields
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
Expand Down Expand Up @@ -80,6 +81,11 @@ public void setup() throws Exception {
);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
    // Self-generate a trial license so the enterprise license check added to the
    // semantic-text inference action filter passes in this integration test.
    // Merge with the superclass settings instead of discarding them: dropping
    // super.nodeSettings(...) would silently lose the base test-cluster config.
    return Settings.builder()
        .put(super.nodeSettings(nodeOrdinal, otherSettings))
        .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
        .build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
Expand All @@ -58,6 +61,8 @@
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;

/**
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
* as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in
Expand All @@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final int batchSize;

/**
 * Creates the filter with the default inference batch size.
 *
 * @param clusterService           used to resolve cluster state for index mappings
 * @param inferenceServiceRegistry registry of available inference services
 * @param modelRegistry            registry used to resolve inference endpoint models
 * @param licenseState             current license state, checked before running inference
 */
public ShardBulkInferenceActionFilter(
    ClusterService clusterService,
    InferenceServiceRegistry inferenceServiceRegistry,
    ModelRegistry modelRegistry,
    XPackLicenseState licenseState
) {
    // The diff paste interleaved the old and new versions of this constructor
    // (duplicate modelRegistry parameter, two this(...) calls); only the new
    // licenseState-aware delegation is valid Java.
    this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
}

/**
 * Creates the filter with an explicit inference batch size.
 *
 * @param clusterService           used to resolve cluster state for index mappings
 * @param inferenceServiceRegistry registry of available inference services
 * @param modelRegistry            registry used to resolve inference endpoint models
 * @param licenseState             current license state; run() checks it against
 *                                 INFERENCE_API_FEATURE before performing inference
 * @param batchSize                number of field inference requests sent per batch
 */
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState,
int batchSize
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.batchSize = batchSize;
}

Expand Down Expand Up @@ -207,6 +216,10 @@ private AsyncBulkShardInferenceAction(

@Override
public void run() {
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
throw LicenseUtils.newComplianceException(XPackField.INFERENCE);
}

Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
Runnable onInferenceCompletion = () -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.BulkItemRequest;
Expand Down Expand Up @@ -40,6 +41,7 @@
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.MockLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
Expand All @@ -49,6 +51,7 @@
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
Expand Down Expand Up @@ -113,7 +116,7 @@ public void tearDownThreadPool() throws Exception {

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testFilterNoop() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat);
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
Expand All @@ -136,14 +139,35 @@ public void testFilterNoop() throws Exception {
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testLicenseInvalidForInference() {
    // A filter backed by an invalid license must reject any bulk request that
    // carries an inference field map with a security exception.
    ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);

    Task task = mock(Task.class);
    ActionListener listener = mock(ActionListener.class);
    ActionFilterChain chain = mock(ActionFilterChain.class);

    BulkShardRequest bulkRequest = new BulkShardRequest(
        new ShardId("test", "test", 0),
        WriteRequest.RefreshPolicy.NONE,
        new BulkItemRequest[0]
    );
    // Attach a semantic-text inference field so the license check is actually exercised.
    bulkRequest.setInferenceFieldMap(
        Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false)))
    );

    assertThrows(
        ElasticsearchSecurityException.class,
        () -> filter.apply(task, TransportShardBulkAction.ACTION_NAME, bulkRequest, listener, chain)
    );
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testInferenceNotFound() throws Exception {
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat
useLegacyFormat,
true
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
Expand Down Expand Up @@ -189,7 +213,8 @@ public void testItemFailures() throws Exception {
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat
useLegacyFormat,
true
);
model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
Expand Down Expand Up @@ -255,7 +280,8 @@ public void testExplicitNull() throws Exception {
threadPool,
Map.of(model.getInferenceEntityId(), model),
randomIntBetween(1, 10),
useLegacyFormat
useLegacyFormat,
true
);

CountDownLatch chainExecuted = new CountDownLatch(1);
Expand Down Expand Up @@ -344,7 +370,13 @@ public void testManyRandomDocs() throws Exception {
modifiedRequests[id] = res[1];
}

ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, randomIntBetween(10, 30), useLegacyFormat);
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
inferenceModelMap,
randomIntBetween(10, 30),
useLegacyFormat,
true
);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
Expand Down Expand Up @@ -379,7 +411,8 @@ private static ShardBulkInferenceActionFilter createFilter(
ThreadPool threadPool,
Map<String, StaticModel> modelMap,
int batchSize,
boolean useLegacyFormat
boolean useLegacyFormat,
boolean isLicenseValidForInference
) {
ModelRegistry modelRegistry = mock(ModelRegistry.class);
Answer<?> unparsedModelAnswer = invocationOnMock -> {
Expand Down Expand Up @@ -437,10 +470,14 @@ private static ShardBulkInferenceActionFilter createFilter(
InferenceServiceRegistry inferenceServiceRegistry = mock(InferenceServiceRegistry.class);
when(inferenceServiceRegistry.getService(any())).thenReturn(Optional.of(inferenceService));

MockLicenseState licenseState = MockLicenseState.createMock();
when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(isLicenseValidForInference);

return new ShardBulkInferenceActionFilter(
createClusterService(useLegacyFormat),
inferenceServiceRegistry,
modelRegistry,
licenseState,
batchSize
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

package org.elasticsearch.xpack.inference.mapper;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.NonDynamicFieldMapperTests;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
Expand All @@ -25,6 +27,11 @@ public void setup() throws Exception {
Utils.storeSparseModel(client());
}

@Override
protected Settings nodeSettings() {
    // Self-generate a trial license so semantic_text mapper operations pass the
    // inference license check. Preserve the superclass node settings rather than
    // replacing them wholesale — discarding super.nodeSettings() drops the base
    // single-node test configuration.
    return Settings.builder()
        .put(super.nodeSettings())
        .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
        .build();
}

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return List.of(LocalStateInferencePlugin.class);
Expand Down