@@ -613,7 +613,7 @@ public void testIndexingPressure() throws Exception {
613613 }
614614
615615 @ SuppressWarnings ("unchecked" )
616- public void testIndexingPressureTripsOnEstimatedInferenceBytes () throws Exception {
616+ public void testIndexingPressureTripsOnInferenceRequestGeneration () throws Exception {
617617 final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure (
618618 Settings .builder ().put (MAX_COORDINATING_BYTES .getKey (), "1b" ).build ()
619619 );
@@ -637,7 +637,8 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exceptio
637637 assertNull (request .items ()[0 ].getPrimaryResponse ());
638638 assertNull (request .items ()[2 ].getPrimaryResponse ());
639639
640- BulkItemResponse doc1Response = request .items ()[1 ].getPrimaryResponse ();
640+ BulkItemRequest doc1Request = request .items ()[1 ];
641+ BulkItemResponse doc1Response = doc1Request .getPrimaryResponse ();
641642 assertNotNull (doc1Response );
642643 assertTrue (doc1Response .isFailed ());
643644 BulkItemResponse .Failure doc1Failure = doc1Response .getFailure ();
@@ -648,6 +649,10 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exceptio
648649 assertThat (doc1Failure .getCause ().getCause (), instanceOf (EsRejectedExecutionException .class ));
649650 assertThat (doc1Failure .getStatus (), is (RestStatus .TOO_MANY_REQUESTS ));
650651
652+ IndexRequest doc1IndexRequest = getIndexRequestOrNull (doc1Request .request ());
653+ assertThat (doc1IndexRequest , notNullValue ());
654+ assertThat (doc1IndexRequest .source (), equalTo (BytesReference .bytes (doc1Source )));
655+
651656 IndexingPressure .Coordinating coordinatingIndexingPressure = indexingPressure .getCoordinating ();
652657 assertThat (coordinatingIndexingPressure , notNullValue ());
653658 verify (coordinatingIndexingPressure ).increment (1 , BytesReference .bytes (doc1Source ).ramBytesUsed ());
@@ -685,6 +690,87 @@ public void testIndexingPressureTripsOnEstimatedInferenceBytes() throws Exceptio
685690 verify (coordinatingIndexingPressure ).close ();
686691 }
687692
693+ @ SuppressWarnings ("unchecked" )
694+ public void testIndexingPressureTripsOnInferenceResponseHandling () throws Exception {
695+ final XContentBuilder doc1Source = IndexRequest .getXContentBuilder (XContentType .JSON , "sparse_field" , "bar" );
696+ final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure (
697+ Settings .builder ().put (MAX_COORDINATING_BYTES .getKey (), (BytesReference .bytes (doc1Source ).ramBytesUsed () + 1 ) + "b" ).build ()
698+ );
699+
700+ final StaticModel sparseModel = StaticModel .createRandomInstance (TaskType .SPARSE_EMBEDDING );
701+ sparseModel .putResult ("bar" , randomChunkedInferenceEmbedding (sparseModel , List .of ("bar" )));
702+
703+ final ShardBulkInferenceActionFilter filter = createFilter (
704+ threadPool ,
705+ Map .of (sparseModel .getInferenceEntityId (), sparseModel ),
706+ useLegacyFormat ,
707+ true
708+ );
709+ filter .setIndexingPressure (indexingPressure );
710+
711+ CountDownLatch chainExecuted = new CountDownLatch (1 );
712+ ActionFilterChain <BulkShardRequest , BulkShardResponse > actionFilterChain = (task , action , request , listener ) -> {
713+ try {
714+ assertNull (request .getInferenceFieldMap ());
715+ assertThat (request .items ().length , equalTo (3 ));
716+
717+ assertNull (request .items ()[0 ].getPrimaryResponse ());
718+ assertNull (request .items ()[2 ].getPrimaryResponse ());
719+
720+ BulkItemRequest doc1Request = request .items ()[1 ];
721+ BulkItemResponse doc1Response = doc1Request .getPrimaryResponse ();
722+ assertNotNull (doc1Response );
723+ assertTrue (doc1Response .isFailed ());
724+ BulkItemResponse .Failure doc1Failure = doc1Response .getFailure ();
725+ assertThat (
726+ doc1Failure .getCause ().getMessage (),
727+ containsString ("Insufficient memory available to insert inference results into document [doc_1]" )
728+ );
729+ assertThat (doc1Failure .getCause ().getCause (), instanceOf (EsRejectedExecutionException .class ));
730+ assertThat (doc1Failure .getStatus (), is (RestStatus .TOO_MANY_REQUESTS ));
731+
732+ IndexRequest doc1IndexRequest = getIndexRequestOrNull (doc1Request .request ());
733+ assertThat (doc1IndexRequest , notNullValue ());
734+ assertThat (doc1IndexRequest .source (), equalTo (BytesReference .bytes (doc1Source )));
735+
736+ IndexingPressure .Coordinating coordinatingIndexingPressure = indexingPressure .getCoordinating ();
737+ assertThat (coordinatingIndexingPressure , notNullValue ());
738+ verify (coordinatingIndexingPressure ).increment (1 , BytesReference .bytes (doc1Source ).ramBytesUsed ());
739+ verify (coordinatingIndexingPressure ).increment (eq (0 ), longThat (l -> l > 0 ));
740+ verify (coordinatingIndexingPressure , times (2 )).increment (anyInt (), anyLong ());
741+
742+ // Verify that the coordinating indexing pressure is maintained through downstream action filters
743+ verify (coordinatingIndexingPressure , never ()).close ();
744+
745+ // Call the listener once the request is successfully processed, like is done in the production code path
746+ listener .onResponse (null );
747+ } finally {
748+ chainExecuted .countDown ();
749+ }
750+ };
751+ ActionListener <BulkShardResponse > actionListener = (ActionListener <BulkShardResponse >) mock (ActionListener .class );
752+ Task task = mock (Task .class );
753+
754+ Map <String , InferenceFieldMetadata > inferenceFieldMap = Map .of (
755+ "sparse_field" ,
756+ new InferenceFieldMetadata ("sparse_field" , sparseModel .getInferenceEntityId (), new String [] { "sparse_field" }, null )
757+ );
758+
759+ BulkItemRequest [] items = new BulkItemRequest [3 ];
760+ items [0 ] = new BulkItemRequest (0 , new IndexRequest ("index" ).id ("doc_0" ).source ("non_inference_field" , "foo" ));
761+ items [1 ] = new BulkItemRequest (1 , new IndexRequest ("index" ).id ("doc_1" ).source (doc1Source ));
762+ items [2 ] = new BulkItemRequest (2 , new IndexRequest ("index" ).id ("doc_2" ).source ("non_inference_field" , "baz" ));
763+
764+ BulkShardRequest request = new BulkShardRequest (new ShardId ("test" , "test" , 0 ), WriteRequest .RefreshPolicy .NONE , items );
765+ request .setInferenceFieldMap (inferenceFieldMap );
766+ filter .apply (task , TransportShardBulkAction .ACTION_NAME , request , actionListener , actionFilterChain );
767+ awaitLatch (chainExecuted , 10 , TimeUnit .SECONDS );
768+
769+ IndexingPressure .Coordinating coordinatingIndexingPressure = indexingPressure .getCoordinating ();
770+ assertThat (coordinatingIndexingPressure , notNullValue ());
771+ verify (coordinatingIndexingPressure ).close ();
772+ }
773+
688774 @ SuppressWarnings ("unchecked" )
689775 private static ShardBulkInferenceActionFilter createFilter (
690776 ThreadPool threadPool ,
0 commit comments