Skip to content

Commit bb99b3b

Browse files
Update unit test to follow more accurate behavior
1 parent aeaf117 commit bb99b3b

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ public void testExplicitNull() throws Exception {
367367
}
368368

369369
@SuppressWarnings({ "unchecked", "rawtypes" })
370-
public void testSkipGeneratingInference() throws Exception {
370+
public void testHandleEmptyInput() throws Exception {
371371
StaticModel model = StaticModel.createRandomInstance();
372372
ShardBulkInferenceActionFilter filter = createFilter(
373373
threadPool,
@@ -383,37 +383,33 @@ public void testSkipGeneratingInference() throws Exception {
383383
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
384384
IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request());
385385

386-
// Create: Empty string
387-
assertThat(XContentMapValues.extractValue("obj", actualRequest.sourceAsMap(), EXPLICIT_NULL), equalTo(""));
388-
assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), EXPLICIT_NULL));
386+
// Create with Empty string
387+
assertInferenceResults(useLegacyFormat, actualRequest, "semantic_text_field", useLegacyFormat ? EXPLICIT_NULL: "", 0);
389388

390-
// Create: whitespace only
389+
// Create with whitespace only
391390
actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request());
392-
assertThat(XContentMapValues.extractValue("obj.field", actualRequest.sourceAsMap(), EXPLICIT_NULL), equalTo(""));
393-
assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), EXPLICIT_NULL));
391+
assertInferenceResults(useLegacyFormat, actualRequest, "semantic_text_field", useLegacyFormat ? EXPLICIT_NULL: " ", 0);
394392

395-
// Update: Empty string
393+
// Update with multiple Whitespaces
396394
actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[2].request());
397-
assertThat(XContentMapValues.extractValue("obj", actualRequest.sourceAsMap(), EXPLICIT_NULL), equalTo(" "));
398-
assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), EXPLICIT_NULL));
399-
400-
// Update: whitespace only
401-
actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request());
402-
assertThat(XContentMapValues.extractValue("obj.field", actualRequest.sourceAsMap(), EXPLICIT_NULL), equalTo(" "));
403-
assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), EXPLICIT_NULL));
395+
assertInferenceResults(useLegacyFormat, actualRequest, "semantic_text_field", useLegacyFormat ? EXPLICIT_NULL: " ", 0);
404396
} finally {
405397
chainExecuted.countDown();
406398
}
407399
};
408400
ActionListener actionListener = mock(ActionListener.class);
409401
Task task = mock(Task.class);
402+
Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
403+
"semantic_text_field",
404+
new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" })
405+
);
410406

411-
BulkItemRequest[] items = new BulkItemRequest[4];
412-
items[0] = new BulkItemRequest(0, new IndexRequest("index").source(Map.of("obj", "")));
413-
items[1] = new BulkItemRequest(1, new IndexRequest("index").source(Map.of("obj", Map.of("field", ""))));
414-
items[2] = new BulkItemRequest(2, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("obj", " "))));
415-
items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("obj", Map.of("field", " ")))));
407+
BulkItemRequest[] items = new BulkItemRequest[3];
408+
items[0] = new BulkItemRequest(0, new IndexRequest("index").source(Map.of("semantic_text_field", "")));
409+
items[1] = new BulkItemRequest(1, new IndexRequest("index").source(Map.of("semantic_text_field", " ")));
410+
items[2] = new BulkItemRequest(2, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("semantic_text_field", " "))));
416411
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
412+
request.setInferenceFieldMap(inferenceFieldMap);
417413
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
418414
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
419415
}
@@ -655,9 +651,8 @@ private static void assertInferenceResults(
655651
assertNotNull(chunks);
656652
assertThat(chunks.size(), equalTo(expectedChunkCount));
657653
} else {
658-
// If the expected chunk count is 0, we expect that no inference has been performed. In this case, the source should not be
659-
// transformed, and thus the semantic text field structure should not be created.
660-
assertNull(chunks);
654+
// If the expected chunk count is 0, we expect that no inference has been performed.
655+
assertTrue(chunks == null || chunks.isEmpty());
661656
}
662657
} else {
663658
assertThat(XContentMapValues.extractValue(fieldName, requestMap, EXPLICIT_NULL), equalTo(expectedOriginalValue));

0 commit comments

Comments
 (0)