Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122293.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122293
summary: Add enterprise license check to inference action for semantic text fields
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.action.filter;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.junit.Before;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;

/**
 * Integration tests verifying that bulk indexing into {@code semantic_text} fields fails with a
 * license-compliance error when the cluster runs a basic (non-enterprise) license, while documents
 * that carry {@code null} semantic_text values (and therefore trigger no inference calls) still
 * index successfully.
 */
public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
    public static final String INDEX_NAME = "test-index";

    // Mapping with one sparse and one dense semantic_text field; the two %s placeholders are
    // filled with the mock inference service names.
    private static final String SEMANTIC_TEXT_MAPPING = """
        {
            "properties": {
                "sparse_field": {
                    "type": "semantic_text",
                    "inference_id": "%s"
                },
                "dense_field": {
                    "type": "semantic_text",
                    "inference_id": "%s"
                }
            }
        }
        """;

    private final boolean useLegacyFormat;

    public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat) {
        this.useLegacyFormat = useLegacyFormat;
    }

    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        // Run each test once with the legacy semantic_text source format and once without.
        return List.of(new Object[] { true }, new Object[] { false });
    }

    @Before
    public void setup() throws Exception {
        Utils.storeSparseModel(client());
        Utils.storeDenseModel(
            client(),
            randomIntBetween(1, 100),
            // dot product means that we need normalized vectors; it's not worth doing that in this test
            randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
            // TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
            randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
        );
    }

    @Override
    protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
        // Force a basic self-generated license so the inference feature's license check fails.
        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "basic").build();
    }

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return Arrays.asList(LocalStateInferencePlugin.class);
    }

    @Override
    public Settings indexSettings() {
        return Settings.builder()
            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
            .build();
    }

    /**
     * Indexing non-null semantic_text values requires inference, which is gated on an enterprise
     * license; every bulk item must fail with an {@link ElasticsearchSecurityException}.
     */
    public void testLicenseInvalidForInference() {
        createIndexWithSemanticTextFields();

        BulkRequestBuilder bulkRequest = client().prepareBulk();
        int totalBulkReqs = randomIntBetween(2, 100);
        for (int i = 0; i < totalBulkReqs; i++) {
            Map<String, Object> source = new HashMap<>();
            source.put("sparse_field", randomSemanticTextInput());
            source.put("dense_field", randomSemanticTextInput());

            bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
        }

        BulkResponse bulkResponse = bulkRequest.get();
        for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
            assertTrue(bulkItemResponse.isFailed());
            assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
            assertThat(
                bulkItemResponse.getFailure().getCause().getMessage(),
                containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
            );
        }
    }

    /**
     * Null semantic_text values trigger no inference calls, so the bulk request must succeed even
     * on a basic license.
     */
    public void testNullSourceSucceeds() {
        createIndexWithSemanticTextFields();

        BulkRequestBuilder bulkRequest = client().prepareBulk();
        int totalBulkReqs = randomIntBetween(2, 100);
        Map<String, Object> source = new HashMap<>();
        source.put("sparse_field", null);
        source.put("dense_field", null);
        for (int i = 0; i < totalBulkReqs; i++) {
            bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
        }

        BulkResponse bulkResponse = bulkRequest.get();
        assertFalse(bulkResponse.hasFailures());
    }

    // Shared setup: creates INDEX_NAME with a sparse and a dense semantic_text field backed by the
    // mock inference services. Extracted because both tests used an identical inline mapping.
    private void createIndexWithSemanticTextFields() {
        prepareCreate(INDEX_NAME).setMapping(
            String.format(
                Locale.ROOT,
                SEMANTIC_TEXT_MAPPING,
                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
                TestDenseInferenceServiceExtension.TestInferenceService.NAME
            )
        ).get();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
Expand Down Expand Up @@ -81,6 +82,11 @@ public void setup() throws Exception {
);
}

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ public Collection<?> createComponents(PluginServices services) {
}
inferenceServiceRegistry.set(serviceRegistry);

var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
shardBulkInferenceActionFilter.set(actionFilter);

var meterRegistry = services.telemetryProvider().getMeterRegistry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContent;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
Expand All @@ -58,6 +61,8 @@
import java.util.Map;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;

/**
* A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
* as {@link SemanticTextFieldMapper} in the index mapping. For each semantic text field referencing fields in
Expand All @@ -76,25 +81,29 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final ClusterService clusterService;
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final int batchSize;

/**
 * Creates the filter with the default inference batch size.
 *
 * @param clusterService           used to read cluster-level state/settings
 * @param inferenceServiceRegistry resolves the inference service backing each inference endpoint
 * @param modelRegistry            resolves model configuration for inference endpoints
 * @param licenseState             current license state, checked before issuing inference requests
 */
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState
) {
this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
}

/**
 * Creates the filter with an explicit inference batch size.
 *
 * @param clusterService           used to read cluster-level state/settings
 * @param inferenceServiceRegistry resolves the inference service backing each inference endpoint
 * @param modelRegistry            resolves model configuration for inference endpoints
 * @param licenseState             current license state, checked before issuing inference requests
 * @param batchSize                number of field values grouped into a single inference request
 */
public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState,
int batchSize
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.batchSize = batchSize;
}

Expand Down Expand Up @@ -561,6 +570,11 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
break;
}

if (INFERENCE_API_FEATURE.check(licenseState) == false) {
addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
break;
}

List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
Expand Down
Loading