6060import org .elasticsearch .xpack .core .XPackField ;
6161import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceEmbedding ;
6262import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceError ;
63- import org .elasticsearch .xpack .inference .InferenceException ;
6463import org .elasticsearch .xpack .inference .InferencePlugin ;
6564import org .elasticsearch .xpack .inference .mapper .SemanticTextField ;
6665import org .elasticsearch .xpack .inference .model .TestModel ;
@@ -613,7 +612,7 @@ public void testIndexingPressure() throws Exception {
613612 }
614613
615614 @ SuppressWarnings ("unchecked" )
616- public void testIndexingPressureTripsOnEstimatedInferenceBytes () {
615+ public void testIndexingPressureTripsOnEstimatedInferenceBytes () throws Exception {
617616 final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure (
618617 Settings .builder ().put (MAX_COORDINATING_BYTES .getKey (), "1b" ).build ()
619618 );
@@ -626,8 +625,41 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() {
626625 );
627626 filter .setIndexingPressure (indexingPressure );
628627
628+ XContentBuilder doc1Source = IndexRequest .getXContentBuilder (XContentType .JSON , "sparse_field" , "bar" );
629+
630+ CountDownLatch chainExecuted = new CountDownLatch (1 );
629631 ActionFilterChain <BulkShardRequest , BulkShardResponse > actionFilterChain = (task , action , request , listener ) -> {
630- fail ("Downstream elements of the action filter chain should not execute" );
632+ try {
633+ assertNull (request .getInferenceFieldMap ());
634+ assertThat (request .items ().length , equalTo (3 ));
635+
636+ assertNull (request .items ()[0 ].getPrimaryResponse ());
637+ assertNull (request .items ()[2 ].getPrimaryResponse ());
638+
639+ BulkItemResponse doc1Response = request .items ()[1 ].getPrimaryResponse ();
640+ assertNotNull (doc1Response );
641+ assertTrue (doc1Response .isFailed ());
642+ BulkItemResponse .Failure doc1Failure = doc1Response .getFailure ();
643+ assertThat (
644+ doc1Failure .getCause ().getMessage (),
645+ containsString ("Insufficient memory available to update source on document [doc_1]" )
646+ );
647+ assertThat (doc1Failure .getCause ().getCause (), instanceOf (EsRejectedExecutionException .class ));
648+ assertThat (doc1Failure .getStatus (), is (RestStatus .TOO_MANY_REQUESTS ));
649+
650+ IndexingPressure .Coordinating coordinatingIndexingPressure = indexingPressure .getCoordinating ();
651+ assertThat (coordinatingIndexingPressure , notNullValue ());
652+ verify (coordinatingIndexingPressure ).increment (1 , BytesReference .bytes (doc1Source ).ramBytesUsed ());
653+ verify (coordinatingIndexingPressure , times (1 )).increment (anyInt (), anyLong ());
654+
655+ // Verify that the coordinating indexing pressure is maintained through downstream action filters
656+ verify (coordinatingIndexingPressure , never ()).close ();
657+
658+ // Call the listener once the request is successfully processed, as is done in the production code path
659+ listener .onResponse (null );
660+ } finally {
661+ chainExecuted .countDown ();
662+ }
631663 };
632664 ActionListener <BulkShardResponse > actionListener = (ActionListener <BulkShardResponse >) mock (ActionListener .class );
633665 Task task = mock (Task .class );
@@ -639,19 +671,13 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() {
639671
640672 BulkItemRequest [] items = new BulkItemRequest [3 ];
641673 items [0 ] = new BulkItemRequest (0 , new IndexRequest ("index" ).id ("doc_0" ).source ("non_inference_field" , "foo" ));
642- items [1 ] = new BulkItemRequest (1 , new IndexRequest ("index" ).id ("doc_1" ).source ("sparse_field" , "bar" ));
674+ items [1 ] = new BulkItemRequest (1 , new IndexRequest ("index" ).id ("doc_1" ).source (doc1Source ));
643675 items [2 ] = new BulkItemRequest (2 , new IndexRequest ("index" ).id ("doc_2" ).source ("non_inference_field" , "baz" ));
644676
645677 BulkShardRequest request = new BulkShardRequest (new ShardId ("test" , "test" , 0 ), WriteRequest .RefreshPolicy .NONE , items );
646678 request .setInferenceFieldMap (inferenceFieldMap );
647-
648- InferenceException exception = assertThrows (
649- InferenceException .class ,
650- () -> filter .apply (task , TransportShardBulkAction .ACTION_NAME , request , actionListener , actionFilterChain )
651- );
652- assertThat (exception .getMessage (), containsString ("Insufficient memory available to perform inference on bulk request" ));
653- assertThat (exception .status (), equalTo (RestStatus .TOO_MANY_REQUESTS ));
654- assertThat (exception .getCause (), instanceOf (EsRejectedExecutionException .class ));
679+ filter .apply (task , TransportShardBulkAction .ACTION_NAME , request , actionListener , actionFilterChain );
680+ awaitLatch (chainExecuted , 10 , TimeUnit .SECONDS );
655681
656682 IndexingPressure .Coordinating coordinatingIndexingPressure = indexingPressure .getCoordinating ();
657683 assertThat (coordinatingIndexingPressure , notNullValue ());
0 commit comments