 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -109,18 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;

     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -146,8 +151,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
         BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
         var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
         if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-            Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+            // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+            IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+            Runnable onInferenceCompletion = () -> chain.proceed(
+                task,
+                action,
+                request,
+                ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+            );
+            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
             return;
         }
     }
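Note: the handle returned by `createCoordinatingOperation` is a `Releasable`, and wrapping the outgoing listener with `ActionListener.releaseAfter` is what guarantees it is closed exactly once, on both the response and the failure path. A minimal sketch of that contract, using local stand-ins for the Elasticsearch `ActionListener` and `Releasable` types rather than the real APIs:

```java
// Sketch only: toy stand-ins showing why wrapping the listener releases the
// coordinating-pressure handle exactly once, whichever way the bulk completes.
interface Releasable extends AutoCloseable {
    @Override
    void close(); // no checked exception, mirroring org.elasticsearch.core.Releasable
}

interface Listener<T> {
    void onResponse(T result);

    void onFailure(Exception e);

    // Equivalent in spirit to ActionListener.releaseAfter(delegate, releasable)
    static <T> Listener<T> releaseAfter(Listener<T> delegate, Releasable releasable) {
        return new Listener<>() {
            @Override
            public void onResponse(T result) {
                try (releasable) { // released after the delegate runs
                    delegate.onResponse(result);
                }
            }

            @Override
            public void onFailure(Exception e) {
                try (releasable) { // released on the failure path too
                    delegate.onFailure(e);
                }
            }
        };
    }
}
```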
@@ -157,12 +169,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         final ProjectMetadata project = clusterService.state().getMetadata().getProject();
         var index = project.index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }

     private record InferenceProvider(InferenceService service, Model model) {}
@@ -232,18 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;

         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }

         @Override
@@ -431,9 +448,9 @@ public void onFailure(Exception exc) {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -447,13 +464,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }

-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -489,6 +506,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                      * This ensures that the field is treated as intentionally cleared,
                      * preventing any unintended carryover of prior inference results.
                      */
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     var slot = ensureResponseAccumulatorSlot(itemIndex);
                     slot.addOrUpdateResponse(
                         new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
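Note: pressure is claimed even when the field was explicitly set to null, because recording the empty response still causes the document source to be rewritten later, and that rewrite must be accounted for up front. A condensed view of the guard, with `claimPressureOnce` and `recordEmptyResponse` as hypothetical stand-ins for `incrementIndexingPressure` and the accumulator call:

```java
import java.util.function.BooleanSupplier;

// Condensed guard from the hunk above, with illustrative names: any path that
// records a response for a document first claims the memory budget, because
// every recorded response later triggers a source rewrite for that document.
final class ClearedFieldGuard {
    static boolean recordClearedField(BooleanSupplier claimPressureOnce, Runnable recordEmptyResponse) {
        if (claimPressureOnce.getAsBoolean() == false) {
            return false; // rejected: the bulk item already carries an inference failure
        }
        recordEmptyResponse.run(); // the field is treated as intentionally cleared
        return true;
    }
}
```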
@@ -510,6 +531,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                     continue;
                 }
+
                 var slot = ensureResponseAccumulatorSlot(itemIndex);
                 final List<String> values;
                 try {
@@ -527,7 +549,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                 List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
-                    inputLength += v.length();
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     if (v.isBlank()) {
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -536,6 +561,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         requests.add(
                             new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                         );
+                        inputLength += v.length();
                     }

                     // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -545,9 +571,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                 }
             }
+
             return inputLength;
         }

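Two details of the loop above are easy to miss: the budget is claimed before each value is enqueued, and `inputLength`, which drives request batching, now grows only for values that actually produce an inference request (it used to grow for blank values too). A condensed, self-contained view of that ordering, with illustrative names rather than the filter's API:

```java
import java.util.List;
import java.util.function.BooleanSupplier;

// Condensed view of the per-value loop above, with illustrative names: claim
// the memory budget before enqueueing, skip blank values (they get an empty
// response instead of a request), and count bytes only for queued requests.
final class EnqueueSketch {
    static long enqueueValues(List<String> values, List<String> requests, BooleanSupplier tryClaimPressure) {
        long inputLength = 0;
        for (String v : values) {
            if (tryClaimPressure.getAsBoolean() == false) {
                return inputLength; // stop early: the item already carries a failure
            }
            if (v.isBlank()) {
                continue; // recorded as an empty inference response, no request sent
            }
            requests.add(v);
            inputLength += v.length(); // counted only once the request is queued
        }
        return inputLength;
    }
}
```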
+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
+        }
+
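The wrapper plus this method implement once-per-document accounting: a document's full source size is charged to the coordinating budget the first time any of its inference fields needs work, and every later call for the same document is a no-op. A toy model of that behavior, where `MemoryBudget` and `TrackedDoc` are stand-ins (not Elasticsearch types) and rejection is signalled by throwing, like `increment()` does:

```java
// Toy model of the once-per-document accounting above. MemoryBudget stands in
// for IndexingPressure.Coordinating and rejects by throwing, like increment().
final class OncePerDocumentAccounting {
    static final class MemoryBudget {
        private long availableBytes;

        MemoryBudget(long availableBytes) {
            this.availableBytes = availableBytes;
        }

        void increment(long bytes) {
            if (bytes > availableBytes) {
                throw new IllegalStateException("memory budget exceeded");
            }
            availableBytes -= bytes;
        }
    }

    static final class TrackedDoc {
        final long sourceBytes;
        boolean accounted; // set after the first successful increment

        TrackedDoc(long sourceBytes) {
            this.sourceBytes = sourceBytes;
        }
    }

    // Returns false on rejection, mirroring how the filter records a per-item
    // failure and stops processing instead of letting the exception escape.
    static boolean incrementOnce(MemoryBudget budget, TrackedDoc doc) {
        if (doc.accounted) {
            return true; // already charged: later fields of the same doc are free
        }
        try {
            budget.increment(doc.sourceBytes);
            doc.accounted = true;
            return true;
        } catch (IllegalStateException rejected) {
            return false;
        }
    }
}
```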
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -624,6 +695,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 inferenceFieldsMap.put(fieldName, result);
             }

+            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -636,6 +708,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already accounted for those
+            // in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
+                indexRequest.source(originalSource, indexRequest.getContentType());
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }

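After the inference results are written into the source, only the growth relative to the original source is charged, since the original bytes were already paid for in `addFieldInferenceRequests`; on rejection the original source is restored before the item is aborted, so the failed item does not carry a half-enriched document. A minimal, self-contained sketch of that charge-the-delta-or-roll-back shape, with all names illustrative and sources modeled as byte arrays:

```java
// Minimal sketch of charge-the-delta-or-roll-back. The budget rejects by
// throwing, like IndexingPressure.Coordinating.increment() in the hunk above.
final class DeltaAccountingSketch {
    private long availableBytes;

    DeltaAccountingSketch(long availableBytes) {
        this.availableBytes = availableBytes;
    }

    private void charge(long bytes) {
        if (bytes > availableBytes) {
            throw new IllegalStateException("memory budget exceeded");
        }
        availableBytes -= bytes;
    }

    // Returns the source that should remain on the request: the enriched one if
    // the extra bytes fit the budget, the original one if they were rejected
    // (mirroring how the filter restores the source before aborting the item).
    byte[] applyRewrite(byte[] originalSource, byte[] modifiedSource) {
        try {
            charge(modifiedSource.length - originalSource.length); // delta only, no new operation
            return modifiedSource;
        } catch (IllegalStateException rejected) {
            return originalSource;
        }
    }
}
```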