Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/122293.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122293
summary: Add enterprise license check to inference action for semantic text fields
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.action.filter;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.junit.Before;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;

public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
    public static final String INDEX_NAME = "test-index";

    // Whether the index stores semantic_text fields in the legacy format (see indexSettings()).
    private final boolean useLegacyFormat;

    public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat) {
        this.useLegacyFormat = useLegacyFormat;
    }

    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        // Run the whole suite once per semantic_text storage format.
        return List.of(new Object[] { true }, new Object[] { false });
    }

    @Before
    public void setup() throws Exception {
        Utils.storeSparseModel(client());
        Utils.storeDenseModel(
            client(),
            randomIntBetween(1, 100),
            // dot product means that we need normalized vectors; it's not worth doing that in this test
            randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
            // TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
            randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
        );
    }

    @Override
    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
        // A self-generated "basic" license does not cover the inference feature, which is what
        // these tests exercise: the bulk action filter must reject inference under this license.
        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build();
    }

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return List.of(LocalStateInferencePlugin.class);
    }

    @Override
    public Settings indexSettings() {
        return Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
            .build();
    }

    /**
     * Creates {@link #INDEX_NAME} with one sparse and one dense {@code semantic_text} field, each
     * backed by the corresponding mock inference service.
     */
    private void createIndexWithSemanticTextFields() {
        prepareCreate(INDEX_NAME).setMapping(
            String.format(
                Locale.ROOT,
                """
                    {
                        "properties": {
                            "sparse_field": {
                                "type": "semantic_text",
                                "inference_id": "%s"
                            },
                            "dense_field": {
                                "type": "semantic_text",
                                "inference_id": "%s"
                            }
                        }
                    }
                    """,
                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
                TestDenseInferenceServiceExtension.TestInferenceService.NAME
            )
        ).get();
    }

    public void testLicenseInvalidForInference() {
        createIndexWithSemanticTextFields();

        BulkRequestBuilder bulkRequest = client().prepareBulk();
        int totalBulkReqs = randomIntBetween(2, 100);
        for (int i = 0; i < totalBulkReqs; i++) {
            Map<String, Object> source = new HashMap<>();
            source.put("sparse_field", randomSemanticTextInput());
            source.put("dense_field", randomSemanticTextInput());

            bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
        }

        // Every bulk item must fail the license check with a security exception naming the inference feature.
        BulkResponse bulkResponse = bulkRequest.get();
        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
            assertTrue(bulkItemResponse.isFailed());
            assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
            assertThat(
                bulkItemResponse.getFailure().getCause().getMessage(),
                containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
            );
        }
    }

    public void testNullSourceSucceeds() {
        createIndexWithSemanticTextFields();

        // Null semantic_text values require no inference calls, so indexing must succeed even
        // though the license does not permit inference.
        BulkRequestBuilder bulkRequest = client().prepareBulk();
        int totalBulkReqs = randomIntBetween(2, 100);
        Map<String, Object> source = new HashMap<>();
        source.put("sparse_field", null);
        source.put("dense_field", null);
        for (int i = 0; i < totalBulkReqs; i++) {
            bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
        }

        BulkResponse bulkResponse = bulkRequest.get();
        assertFalse(bulkResponse.hasFailures());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
Expand Down Expand Up @@ -81,6 +82,11 @@ public void setup() throws Exception {
);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
Expand All @@ -58,6 +61,8 @@
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;

/**
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
* as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in
Expand All @@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final int batchSize;

public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry
ModelRegistry modelRegistry,
XPackLicenseState licenseState
) {
this(clusterService, inferenceServiceRegistry, modelRegistry, DEFAULT_BATCH_SIZE);
this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
}

public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState,
int batchSize
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.batchSize = batchSize;
}

Expand Down Expand Up @@ -561,6 +570,11 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
break;
}

if (INFERENCE_API_FEATURE.check(licenseState) == false) {
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
break;
}

List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
Expand Down
Loading