 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;

     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -145,8 +150,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
             BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
             var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
             if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-                Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+                // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+                Runnable onInferenceCompletion = () -> chain.proceed(
+                    task,
+                    action,
+                    request,
+                    ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+                );
+                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
                 return;
             }
         }
@@ -156,11 +168,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }

     private record InferenceProvider(InferenceService service, Model model) {}
@@ -230,18 +244,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;

         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }

         @Override
@@ -429,9 +446,9 @@ public void onFailure(Exception exc) {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -445,13 +462,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }

-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -487,6 +504,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                          * This ensures that the field is treated as intentionally cleared,
                          * preventing any unintended carryover of prior inference results.
                          */
+                        if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                            return inputLength;
+                        }
+
                         var slot = ensureResponseAccumulatorSlot(itemIndex);
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -508,6 +529,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         }
                         continue;
                     }
+
                     var slot = ensureResponseAccumulatorSlot(itemIndex);
                     final List<String> values;
                     try {
@@ -525,7 +547,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                     int offsetAdjustment = 0;
                     for (String v : values) {
-                        inputLength += v.length();
+                        if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                            return inputLength;
+                        }
+
                         if (v.isBlank()) {
                             slot.addOrUpdateResponse(
                                 new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -534,6 +559,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                             requests.add(
                                 new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                             );
+                            inputLength += v.length();
                         }

                         // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -543,9 +569,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                 }
             }
+
             return inputLength;
         }

+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -622,6 +693,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 inferenceFieldsMap.put(fieldName, result);
             }

+            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -634,6 +706,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already accounted for those
+            // in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
+                indexRequest.source(originalSource, indexRequest.getContentType());
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }

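For readers unfamiliar with the coordinating indexing-pressure API this change builds on, here is a minimal sketch (not part of the PR) of the reserve/reject/release pattern the filter follows. It relies only on the calls visible in the diff — `createCoordinatingOperation`, `Coordinating.increment`, and release via `ActionListener.releaseAfter` — and the class, method, and parameter names below are illustrative assumptions, not code from this change.

```java
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.index.IndexingPressure;

// Hypothetical helper showing the lifecycle of a coordinating indexing-pressure reservation.
class CoordinatingPressureSketch {
    static void reserveAndRelease(IndexingPressure indexingPressure, ActionListener<Void> listener, long estimatedSourceBytes) {
        // The boolean flag mirrors the createCoordinatingOperation(false) call used by the filter above.
        IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
        try {
            // Reserve memory before rewriting a document source: one operation per document,
            // plus its estimated size in bytes (same shape as increment(1, bytes) in the diff).
            coordinating.increment(1, estimatedSourceBytes);
        } catch (EsRejectedExecutionException e) {
            // Reservation rejected under memory pressure: release it and fail instead of proceeding.
            coordinating.close();
            listener.onFailure(e);
            return;
        }
        // Wrap the listener so the reservation is released only after it completes,
        // i.e. after the downstream indexing operations have finished.
        ActionListener<Void> releasing = ActionListener.releaseAfter(listener, coordinating);
        releasing.onResponse(null);
    }
}
```

The point of the pattern, as the comment in the diff says, is that the reservation taken before inference is not dropped when inference returns; it is tied to the listener handed to `chain.proceed`, so the coordinating bytes stay accounted for until the bulk shard request itself completes.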