Skip to content

Commit bf420c8

Browse files
committed
Added unit test
1 parent 8a28093 commit bf420c8

File tree

1 file changed

+88
-2
lines changed

1 file changed

+88
-2
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ public void testIndexingPressure() throws Exception {
613613
}
614614

615615
@SuppressWarnings("unchecked")
616-
public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exception {
616+
public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception {
617617
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
618618
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
619619
);
@@ -637,7 +637,8 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exceptio
637637
assertNull(request.items()[0].getPrimaryResponse());
638638
assertNull(request.items()[2].getPrimaryResponse());
639639

640-
BulkItemResponse doc1Response = request.items()[1].getPrimaryResponse();
640+
BulkItemRequest doc1Request = request.items()[1];
641+
BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
641642
assertNotNull(doc1Response);
642643
assertTrue(doc1Response.isFailed());
643644
BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
@@ -648,6 +649,10 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exceptio
648649
assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
649650
assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
650651

652+
IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
653+
assertThat(doc1IndexRequest, notNullValue());
654+
assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
655+
651656
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
652657
assertThat(coordinatingIndexingPressure, notNullValue());
653658
verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed());
@@ -685,6 +690,87 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exceptio
685690
verify(coordinatingIndexingPressure).close();
686691
}
687692

693+
@SuppressWarnings("unchecked")
694+
public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
695+
final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
696+
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
697+
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (BytesReference.bytes(doc1Source).ramBytesUsed() + 1) + "b").build()
698+
);
699+
700+
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
701+
sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar")));
702+
703+
final ShardBulkInferenceActionFilter filter = createFilter(
704+
threadPool,
705+
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
706+
useLegacyFormat,
707+
true
708+
);
709+
filter.setIndexingPressure(indexingPressure);
710+
711+
CountDownLatch chainExecuted = new CountDownLatch(1);
712+
ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
713+
try {
714+
assertNull(request.getInferenceFieldMap());
715+
assertThat(request.items().length, equalTo(3));
716+
717+
assertNull(request.items()[0].getPrimaryResponse());
718+
assertNull(request.items()[2].getPrimaryResponse());
719+
720+
BulkItemRequest doc1Request = request.items()[1];
721+
BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
722+
assertNotNull(doc1Response);
723+
assertTrue(doc1Response.isFailed());
724+
BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
725+
assertThat(
726+
doc1Failure.getCause().getMessage(),
727+
containsString("Insufficient memory available to insert inference results into document [doc_1]")
728+
);
729+
assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
730+
assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
731+
732+
IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
733+
assertThat(doc1IndexRequest, notNullValue());
734+
assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
735+
736+
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
737+
assertThat(coordinatingIndexingPressure, notNullValue());
738+
verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed());
739+
verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0));
740+
verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong());
741+
742+
// Verify that the coordinating indexing pressure is maintained through downstream action filters
743+
verify(coordinatingIndexingPressure, never()).close();
744+
745+
// Call the listener once the request is successfully processed, like is done in the production code path
746+
listener.onResponse(null);
747+
} finally {
748+
chainExecuted.countDown();
749+
}
750+
};
751+
ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
752+
Task task = mock(Task.class);
753+
754+
Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
755+
"sparse_field",
756+
new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
757+
);
758+
759+
BulkItemRequest[] items = new BulkItemRequest[3];
760+
items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
761+
items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
762+
items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
763+
764+
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
765+
request.setInferenceFieldMap(inferenceFieldMap);
766+
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
767+
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
768+
769+
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
770+
assertThat(coordinatingIndexingPressure, notNullValue());
771+
verify(coordinatingIndexingPressure).close();
772+
}
773+
688774
@SuppressWarnings("unchecked")
689775
private static ShardBulkInferenceActionFilter createFilter(
690776
ThreadPool threadPool,

0 commit comments

Comments
 (0)