 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;
 
     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -145,8 +150,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
             BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
             var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
             if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-                Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+                // Maintain coordinating indexing pressure from inference until the indexing operations are complete
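+                // The boolean argument is presumably the forceExecution flag: passing false allows increments against
+                // this operation to be rejected with EsRejectedExecutionException when memory limits are reached.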
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+                Runnable onInferenceCompletion = () -> chain.proceed(
+                    task,
+                    action,
+                    request,
+                    ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+                );
+                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
                 return;
             }
         }
@@ -156,11 +168,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }
 
     private record InferenceProvider(InferenceService service, Model model) {}
@@ -230,18 +244,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;
 
         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }
 
         @Override
@@ -429,9 +446,9 @@ public void onFailure(Exception exc) {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -445,13 +462,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }
 
-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -487,6 +504,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                      * This ensures that the field is treated as intentionally cleared,
                      * preventing any unintended carryover of prior inference results.
                      */
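+                    // Clearing the field still modifies the document source, so reserve indexing pressure for it first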
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     var slot = ensureResponseAccumulatorSlot(itemIndex);
                     slot.addOrUpdateResponse(
                         new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -508,6 +529,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                     continue;
                 }
+
                 var slot = ensureResponseAccumulatorSlot(itemIndex);
                 final List<String> values;
                 try {
@@ -525,7 +547,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                 List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
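+                    // Reserve indexing pressure before processing the value; this is a no-op after the first value
+                    // because the wrapper records that pressure was already incremented for this document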
-                    inputLength += v.length();
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     if (v.isBlank()) {
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -534,6 +559,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         requests.add(
                             new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                         );
+                        inputLength += v.length();
                     }
 
                     // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -543,9 +569,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                 }
             }
+
             return inputLength;
         }
 
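+        /**
+         * Wraps an {@link IndexRequest}, recording whether coordinating indexing pressure has already been
+         * incremented for it so that each document's source is counted at most once.
+         */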
+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
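+        /**
+         * Reserves coordinating indexing pressure for the document's current source size, at most once per document.
+         * On rejection, records an inference failure for the item and returns false.
+         */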
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -622,6 +693,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 inferenceFieldsMap.put(fieldName, result);
             }
 
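+            // Capture the original source so that the size delta can be accounted for below, and so that it can be
+            // restored if the extra indexing pressure is rejected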
+            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -634,6 +706,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already
+            // accounted for those in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
+                indexRequest.source(originalSource, indexRequest.getContentType());
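+                // Abort the item so the document is failed rather than indexed without its memory being accounted for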
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }
 