Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,20 @@

package org.elasticsearch.xpack.inference.action.filter;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
Expand All @@ -33,34 +30,16 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.instanceOf;

public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
public static final String INDEX_NAME = "test-index";

private final boolean useLegacyFormat;
private final boolean useSyntheticSource;

public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat, boolean useSyntheticSource) {
this.useLegacyFormat = useLegacyFormat;
this.useSyntheticSource = useSyntheticSource;
}

@ParametersFactory
public static Iterable<Object[]> parameters() throws Exception {
return List.of(
new Object[] { true, false },
new Object[] { true, true },
new Object[] { false, false },
new Object[] { false, true }
);
}

@Before
public void setup() throws Exception {
Utils.storeSparseModel(client());
Expand All @@ -86,13 +65,7 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {

@Override
public Settings indexSettings() {
var builder = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
if (useSyntheticSource) {
builder.put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true);
builder.put(IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name());
}
var builder = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10));
return builder.build();
}

Expand Down Expand Up @@ -123,16 +96,56 @@ public void testLicenseInvalidForInference() {
int totalBulkReqs = randomIntBetween(2, 100);
for (int i = 0; i < totalBulkReqs; i++) {
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", rarely() ? null : randomSemanticTextInput());
source.put("dense_field", rarely() ? null : randomSemanticTextInput());
source.put("sparse_field", randomSemanticTextInput());
source.put("dense_field", randomSemanticTextInput());

bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
}

BulkResponse bulkResponse = bulkRequest.get();
for (BulkItemResponse itemResponse : bulkResponse) {
assertTrue(itemResponse.isFailed());
assertThat(itemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
assertTrue(bulkItemResponse.isFailed());
assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
assertThat(
bulkItemResponse.getFailure().getCause().getMessage(),
containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
);
}
}

public void testNullSourceSucceeds() {
prepareCreate(INDEX_NAME).setMapping(
String.format(
Locale.ROOT,
"""
{
"properties": {
"sparse_field": {
"type": "semantic_text",
"inference_id": "%s"
},
"dense_field": {
"type": "semantic_text",
"inference_id": "%s"
}
}
}
""",
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
TestDenseInferenceServiceExtension.TestInferenceService.NAME
)
).get();

BulkRequestBuilder bulkRequest = client().prepareBulk();
int totalBulkReqs = randomIntBetween(2, 100);
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", null);
source.put("dense_field", null);
for (int i = 0; i < totalBulkReqs; i++) {
bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
}

BulkResponse bulkResponse = bulkRequest.get();
assertFalse(bulkResponse.hasFailures());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.inference.InferencePlugin;
Expand Down Expand Up @@ -140,16 +141,26 @@ public void testFilterNoop() throws Exception {
}

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testLicenseInvalidForInference() {
public void testLicenseInvalidForInference() throws InterruptedException {
StaticModel model = StaticModel.createRandomInstance();
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
assertThat(bulkShardRequest.items().length, equalTo(1));
try {
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
assertThat(bulkShardRequest.items().length, equalTo(1));

BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
assertNotNull(failure);
assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class));
assertThat(
failure.getMessage(),
containsString(org.elasticsearch.core.Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird, why didn't IntelliJ want to use the already-imported org.elasticsearch.common.Strings?

);
} finally {
chainExecuted.countDown();
}

BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
assertNotNull(failure);
assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class));
};
ActionListener actionListener = mock(ActionListener.class);
Task task = mock(Task.class);
Expand All @@ -164,6 +175,7 @@ public void testLicenseInvalidForInference() {
request.setInferenceFieldMap(inferenceFieldMap);

filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
}

@SuppressWarnings({ "unchecked", "rawtypes" })
Expand Down