Skip to content

Commit af3f1e9

Browse files
Cleaning up tests
1 parent 7604667 commit af3f1e9

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterBasicLicenseIT.java

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,20 @@
77

88
package org.elasticsearch.xpack.inference.action.filter;
99

10-
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
11-
1210
import org.elasticsearch.ElasticsearchSecurityException;
1311
import org.elasticsearch.action.bulk.BulkItemResponse;
1412
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1513
import org.elasticsearch.action.bulk.BulkResponse;
1614
import org.elasticsearch.action.index.IndexRequestBuilder;
1715
import org.elasticsearch.cluster.metadata.IndexMetadata;
1816
import org.elasticsearch.common.settings.Settings;
19-
import org.elasticsearch.index.IndexSettings;
20-
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
21-
import org.elasticsearch.index.mapper.SourceFieldMapper;
17+
import org.elasticsearch.core.Strings;
2218
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
2319
import org.elasticsearch.inference.SimilarityMeasure;
2420
import org.elasticsearch.license.LicenseSettings;
2521
import org.elasticsearch.plugins.Plugin;
2622
import org.elasticsearch.test.ESIntegTestCase;
23+
import org.elasticsearch.xpack.core.XPackField;
2724
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
2825
import org.elasticsearch.xpack.inference.Utils;
2926
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
@@ -33,34 +30,16 @@
3330
import java.util.Arrays;
3431
import java.util.Collection;
3532
import java.util.HashMap;
36-
import java.util.List;
3733
import java.util.Locale;
3834
import java.util.Map;
3935

4036
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
37+
import static org.hamcrest.Matchers.containsString;
4138
import static org.hamcrest.Matchers.instanceOf;
4239

4340
public class ShardBulkInferenceActionFilterBasicLicenseIT extends ESIntegTestCase {
4441
public static final String INDEX_NAME = "test-index";
4542

46-
private final boolean useLegacyFormat;
47-
private final boolean useSyntheticSource;
48-
49-
public ShardBulkInferenceActionFilterBasicLicenseIT(boolean useLegacyFormat, boolean useSyntheticSource) {
50-
this.useLegacyFormat = useLegacyFormat;
51-
this.useSyntheticSource = useSyntheticSource;
52-
}
53-
54-
@ParametersFactory
55-
public static Iterable<Object[]> parameters() throws Exception {
56-
return List.of(
57-
new Object[] { true, false },
58-
new Object[] { true, true },
59-
new Object[] { false, false },
60-
new Object[] { false, true }
61-
);
62-
}
63-
6443
@Before
6544
public void setup() throws Exception {
6645
Utils.storeSparseModel(client());
@@ -86,13 +65,7 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {
8665

8766
@Override
8867
public Settings indexSettings() {
89-
var builder = Settings.builder()
90-
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10))
91-
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat);
92-
if (useSyntheticSource) {
93-
builder.put(IndexSettings.RECOVERY_USE_SYNTHETIC_SOURCE_SETTING.getKey(), true);
94-
builder.put(IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SourceFieldMapper.Mode.SYNTHETIC.name());
95-
}
68+
var builder = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10));
9669
return builder.build();
9770
}
9871

@@ -123,16 +96,56 @@ public void testLicenseInvalidForInference() {
12396
int totalBulkReqs = randomIntBetween(2, 100);
12497
for (int i = 0; i < totalBulkReqs; i++) {
12598
Map<String, Object> source = new HashMap<>();
126-
source.put("sparse_field", rarely() ? null : randomSemanticTextInput());
127-
source.put("dense_field", rarely() ? null : randomSemanticTextInput());
99+
source.put("sparse_field", randomSemanticTextInput());
100+
source.put("dense_field", randomSemanticTextInput());
128101

129102
bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
130103
}
131104

132105
BulkResponse bulkResponse = bulkRequest.get();
133-
for (BulkItemResponse itemResponse : bulkResponse) {
134-
assertTrue(itemResponse.isFailed());
135-
assertThat(itemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
106+
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
107+
assertTrue(bulkItemResponse.isFailed());
108+
assertThat(bulkItemResponse.getFailure().getCause(), instanceOf(ElasticsearchSecurityException.class));
109+
assertThat(
110+
bulkItemResponse.getFailure().getCause().getMessage(),
111+
containsString(Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
112+
);
136113
}
137114
}
115+
116+
public void testNullSourceSucceeds() {
117+
prepareCreate(INDEX_NAME).setMapping(
118+
String.format(
119+
Locale.ROOT,
120+
"""
121+
{
122+
"properties": {
123+
"sparse_field": {
124+
"type": "semantic_text",
125+
"inference_id": "%s"
126+
},
127+
"dense_field": {
128+
"type": "semantic_text",
129+
"inference_id": "%s"
130+
}
131+
}
132+
}
133+
""",
134+
TestSparseInferenceServiceExtension.TestInferenceService.NAME,
135+
TestDenseInferenceServiceExtension.TestInferenceService.NAME
136+
)
137+
).get();
138+
139+
BulkRequestBuilder bulkRequest = client().prepareBulk();
140+
int totalBulkReqs = randomIntBetween(2, 100);
141+
Map<String, Object> source = new HashMap<>();
142+
source.put("sparse_field", null);
143+
source.put("dense_field", null);
144+
for (int i = 0; i < totalBulkReqs; i++) {
145+
bulkRequest.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(Long.toString(i)).setSource(source));
146+
}
147+
148+
BulkResponse bulkResponse = bulkRequest.get();
149+
assertFalse(bulkResponse.hasFailures());
150+
}
138151
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.elasticsearch.threadpool.ThreadPool;
5050
import org.elasticsearch.xcontent.XContentType;
5151
import org.elasticsearch.xcontent.json.JsonXContent;
52+
import org.elasticsearch.xpack.core.XPackField;
5253
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
5354
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
5455
import org.elasticsearch.xpack.inference.InferencePlugin;
@@ -140,16 +141,26 @@ public void testFilterNoop() throws Exception {
140141
}
141142

142143
@SuppressWarnings({ "unchecked", "rawtypes" })
143-
public void testLicenseInvalidForInference() {
144+
public void testLicenseInvalidForInference() throws InterruptedException {
144145
StaticModel model = StaticModel.createRandomInstance();
145146
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
147+
CountDownLatch chainExecuted = new CountDownLatch(1);
146148
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
147-
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
148-
assertThat(bulkShardRequest.items().length, equalTo(1));
149+
try {
150+
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
151+
assertThat(bulkShardRequest.items().length, equalTo(1));
152+
153+
BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
154+
assertNotNull(failure);
155+
assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class));
156+
assertThat(
157+
failure.getMessage(),
158+
containsString(org.elasticsearch.core.Strings.format("current license is non-compliant for [%s]", XPackField.INFERENCE))
159+
);
160+
} finally {
161+
chainExecuted.countDown();
162+
}
149163

150-
BulkItemResponse.Failure failure = bulkShardRequest.items()[0].getPrimaryResponse().getFailure();
151-
assertNotNull(failure);
152-
assertThat(failure.getCause(), instanceOf(ElasticsearchSecurityException.class));
153164
};
154165
ActionListener actionListener = mock(ActionListener.class);
155166
Task task = mock(Task.class);
@@ -164,6 +175,7 @@ public void testLicenseInvalidForInference() {
164175
request.setInferenceFieldMap(inferenceFieldMap);
165176

166177
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
178+
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
167179
}
168180

169181
@SuppressWarnings({ "unchecked", "rawtypes" })

0 commit comments

Comments
 (0)