77
88package org .elasticsearch .xpack .inference .action .filter ;
99
10- import org .apache .lucene .util .SetOnce ;
1110import org .elasticsearch .ElasticsearchStatusException ;
1211import org .elasticsearch .ExceptionsHelper ;
1312import org .elasticsearch .ResourceNotFoundException ;
@@ -112,20 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
112111 private final InferenceServiceRegistry inferenceServiceRegistry ;
113112 private final ModelRegistry modelRegistry ;
114113 private final XPackLicenseState licenseState ;
114+ private final IndexingPressure indexingPressure ;
115115 private volatile long batchSizeInBytes ;
116116
117- private final SetOnce <IndexingPressure > indexingPressure = new SetOnce <>();
118-
119117 public ShardBulkInferenceActionFilter (
120118 ClusterService clusterService ,
121119 InferenceServiceRegistry inferenceServiceRegistry ,
122120 ModelRegistry modelRegistry ,
123- XPackLicenseState licenseState
121+ XPackLicenseState licenseState ,
122+ IndexingPressure indexingPressure
124123 ) {
125124 this .clusterService = clusterService ;
126125 this .inferenceServiceRegistry = inferenceServiceRegistry ;
127126 this .modelRegistry = modelRegistry ;
128127 this .licenseState = licenseState ;
128+ this .indexingPressure = indexingPressure ;
129129 this .batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE .get (clusterService .getSettings ()).getBytes ();
130130 clusterService .getClusterSettings ().addSettingsUpdateConsumer (INDICES_INFERENCE_BATCH_SIZE , this ::setBatchSize );
131131 }
@@ -134,10 +134,6 @@ private void setBatchSize(ByteSizeValue newBatchSize) {
134134 batchSizeInBytes = newBatchSize .getBytes ();
135135 }
136136
137- public void setIndexingPressure (IndexingPressure indexingPressure ) {
138- this .indexingPressure .set (indexingPressure );
139- }
140-
141137 @ Override
142138 public String actionName () {
143139 return TransportShardBulkAction .ACTION_NAME ;
@@ -156,14 +152,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
156152 var fieldInferenceMetadata = bulkShardRequest .consumeInferenceFieldMap ();
157153 if (fieldInferenceMetadata != null && fieldInferenceMetadata .isEmpty () == false ) {
158154 // Maintain coordinating indexing pressure from inference until the indexing operations are complete
159- CoordinatingIndexingPressureWrapper coordinatingWrapper = startCoordinatingOperations ( );
155+ IndexingPressure . Coordinating coordinatingIndexingPressure = indexingPressure . createCoordinatingOperation ( false );
160156 Runnable onInferenceCompletion = () -> chain .proceed (
161157 task ,
162158 action ,
163159 request ,
164- ActionListener .releaseAfter (listener , coordinatingWrapper )
160+ ActionListener .releaseAfter (listener , coordinatingIndexingPressure )
165161 );
166- processBulkShardRequest (fieldInferenceMetadata , bulkShardRequest , onInferenceCompletion , coordinatingWrapper );
162+ processBulkShardRequest (fieldInferenceMetadata , bulkShardRequest , onInferenceCompletion , coordinatingIndexingPressure );
167163 return ;
168164 }
169165 }
@@ -174,22 +170,13 @@ private void processBulkShardRequest(
174170 Map <String , InferenceFieldMetadata > fieldInferenceMap ,
175171 BulkShardRequest bulkShardRequest ,
176172 Runnable onCompletion ,
177- CoordinatingIndexingPressureWrapper coordinatingWrapper
173+ IndexingPressure . Coordinating coordinatingIndexingPressure
178174 ) {
179175 final ProjectMetadata project = clusterService .state ().getMetadata ().getProject ();
180176 var index = project .index (bulkShardRequest .index ());
181177 boolean useLegacyFormat = InferenceMetadataFieldsMapper .isEnabled (index .getSettings ()) == false ;
182- new AsyncBulkShardInferenceAction (useLegacyFormat , fieldInferenceMap , bulkShardRequest , onCompletion , coordinatingWrapper ).run ();
183- }
184-
185- private CoordinatingIndexingPressureWrapper startCoordinatingOperations () {
186- IndexingPressure .Coordinating coordinating = null ;
187- IndexingPressure localIndexingPressure = indexingPressure .get ();
188- if (localIndexingPressure != null ) {
189- coordinating = localIndexingPressure .createCoordinatingOperation (false );
190- }
191-
192- return new CoordinatingIndexingPressureWrapper (coordinating );
178+ new AsyncBulkShardInferenceAction (useLegacyFormat , fieldInferenceMap , bulkShardRequest , onCompletion , coordinatingIndexingPressure )
179+ .run ();
193180 }
194181
195182 private record InferenceProvider (InferenceService service , Model model ) {}
@@ -259,21 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
259246 private final BulkShardRequest bulkShardRequest ;
260247 private final Runnable onCompletion ;
261248 private final AtomicArray <FieldInferenceResponseAccumulator > inferenceResults ;
262- private final CoordinatingIndexingPressureWrapper coordinatingWrapper ;
249+ private final IndexingPressure . Coordinating coordinatingIndexingPressure ;
263250
264251 private AsyncBulkShardInferenceAction (
265252 boolean useLegacyFormat ,
266253 Map <String , InferenceFieldMetadata > fieldInferenceMap ,
267254 BulkShardRequest bulkShardRequest ,
268255 Runnable onCompletion ,
269- CoordinatingIndexingPressureWrapper coordinatingWrapper
256+ IndexingPressure . Coordinating coordinatingIndexingPressure
270257 ) {
271258 this .useLegacyFormat = useLegacyFormat ;
272259 this .fieldInferenceMap = fieldInferenceMap ;
273260 this .bulkShardRequest = bulkShardRequest ;
274261 this .inferenceResults = new AtomicArray <>(bulkShardRequest .items ().length );
275262 this .onCompletion = onCompletion ;
276- this .coordinatingWrapper = coordinatingWrapper ;
263+ this .coordinatingIndexingPressure = coordinatingIndexingPressure ;
277264 }
278265
279266 @ Override
@@ -612,8 +599,7 @@ private void setIndexingPressureIncremented() {
612599
613600 private boolean incrementIndexingPressure (IndexRequestWithIndexingPressure indexRequest , int itemIndex ) {
614601 boolean success = true ;
615- IndexingPressure .Coordinating coordinatingIndexingPressure = coordinatingWrapper .coordinating ();
616- if (coordinatingIndexingPressure != null && indexRequest .isIndexingPressureIncremented () == false ) {
602+ if (indexRequest .isIndexingPressureIncremented () == false ) {
617603 try {
618604 // Track operation count as one operation per document source update
619605 coordinatingIndexingPressure .increment (1 , indexRequest .getIndexRequest ().source ().ramBytesUsed ());
@@ -724,23 +710,20 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
724710 }
725711 long modifiedSourceSize = indexRequest .source ().ramBytesUsed ();
726712
727- IndexingPressure .Coordinating coordinatingIndexingPressure = coordinatingWrapper .coordinating ();
728- if (coordinatingIndexingPressure != null ) {
729- // Add the indexing pressure from the source modifications.
730- // Don't increment operation count because we count one source update as one operation, and we already accounted for those
731- // in addFieldInferenceRequests.
732- try {
733- coordinatingIndexingPressure .increment (0 , modifiedSourceSize - originalSource .ramBytesUsed ());
734- } catch (EsRejectedExecutionException e ) {
735- indexRequest .source (originalSource , indexRequest .getContentType ());
736- item .abort (
737- item .index (),
738- new InferenceException (
739- "Insufficient memory available to insert inference results into document [" + indexRequest .id () + "]" ,
740- e
741- )
742- );
743- }
713+ // Add the indexing pressure from the source modifications.
714+ // Don't increment operation count because we count one source update as one operation, and we already accounted for those
715+ // in addFieldInferenceRequests.
716+ try {
717+ coordinatingIndexingPressure .increment (0 , modifiedSourceSize - originalSource .ramBytesUsed ());
718+ } catch (EsRejectedExecutionException e ) {
719+ indexRequest .source (originalSource , indexRequest .getContentType ());
720+ item .abort (
721+ item .index (),
722+ new InferenceException (
723+ "Insufficient memory available to insert inference results into document [" + indexRequest .id () + "]" ,
724+ e
725+ )
726+ );
744727 }
745728 }
746729 }
@@ -791,13 +774,4 @@ public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
791774 return Collections .emptyIterator ();
792775 }
793776 }
794-
795- private record CoordinatingIndexingPressureWrapper (@ Nullable IndexingPressure .Coordinating coordinating ) implements Releasable {
796- @ Override
797- public void close () {
798- if (coordinating != null ) {
799- coordinating .close ();
800- }
801- }
802- }
803777}
0 commit comments