Skip to content

Commit e5f64ff

Browse files
committed
Fix testIndexingPressureTripsOnEstimatedInferenceBytes
1 parent fcf7387 commit e5f64ff

File tree

1 file changed

+38
-12
lines changed

1 file changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
import org.elasticsearch.xpack.core.XPackField;
6161
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
6262
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
63-
import org.elasticsearch.xpack.inference.InferenceException;
6463
import org.elasticsearch.xpack.inference.InferencePlugin;
6564
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
6665
import org.elasticsearch.xpack.inference.model.TestModel;
@@ -613,7 +612,7 @@ public void testIndexingPressure() throws Exception {
613612
}
614613

615614
@SuppressWarnings("unchecked")
616-
public void testIndexingPressureTripsOnEstimatedInferenceBytes() {
615+
public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exception {
617616
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
618617
Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
619618
);
@@ -626,8 +625,41 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() {
626625
);
627626
filter.setIndexingPressure(indexingPressure);
628627

628+
XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
629+
630+
CountDownLatch chainExecuted = new CountDownLatch(1);
629631
ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
630-
fail("Downstream elements of the action filter chain should not execute");
632+
try {
633+
assertNull(request.getInferenceFieldMap());
634+
assertThat(request.items().length, equalTo(3));
635+
636+
assertNull(request.items()[0].getPrimaryResponse());
637+
assertNull(request.items()[2].getPrimaryResponse());
638+
639+
BulkItemResponse doc1Response = request.items()[1].getPrimaryResponse();
640+
assertNotNull(doc1Response);
641+
assertTrue(doc1Response.isFailed());
642+
BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
643+
assertThat(
644+
doc1Failure.getCause().getMessage(),
645+
containsString("Insufficient memory available to update source on document [doc_1]")
646+
);
647+
assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
648+
assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
649+
650+
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
651+
assertThat(coordinatingIndexingPressure, notNullValue());
652+
verify(coordinatingIndexingPressure).increment(1, BytesReference.bytes(doc1Source).ramBytesUsed());
653+
verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());
654+
655+
// Verify that the coordinating indexing pressure is maintained through downstream action filters
656+
verify(coordinatingIndexingPressure, never()).close();
657+
658+
// Call the listener once the request is successfully processed, like is done in the production code path
659+
listener.onResponse(null);
660+
} finally {
661+
chainExecuted.countDown();
662+
}
631663
};
632664
ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
633665
Task task = mock(Task.class);
@@ -639,19 +671,13 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() {
639671

640672
BulkItemRequest[] items = new BulkItemRequest[3];
641673
items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
642-
items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source("sparse_field", "bar"));
674+
items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
643675
items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
644676

645677
BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
646678
request.setInferenceFieldMap(inferenceFieldMap);
647-
648-
InferenceException exception = assertThrows(
649-
InferenceException.class,
650-
() -> filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain)
651-
);
652-
assertThat(exception.getMessage(), containsString("Insufficient memory available to perform inference on bulk request"));
653-
assertThat(exception.status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
654-
assertThat(exception.getCause(), instanceOf(EsRejectedExecutionException.class));
679+
filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
680+
awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
655681

656682
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
657683
assertThat(coordinatingIndexingPressure, notNullValue());

0 commit comments

Comments (0)