Skip to content

Commit ecb8e02

Browse files
committed
Pass indexing pressure in constructor
1 parent 0ff48cf commit ecb8e02

File tree

3 files changed: +71 additions, -65 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,13 @@ public Collection<?> createComponents(PluginServices services) {
323323
}
324324
inferenceServiceRegistry.set(serviceRegistry);
325325

326-
var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
327-
actionFilter.setIndexingPressure(services.indexingPressure());
326+
var actionFilter = new ShardBulkInferenceActionFilter(
327+
services.clusterService(),
328+
serviceRegistry,
329+
modelRegistry,
330+
getLicenseState(),
331+
services.indexingPressure()
332+
);
328333
shardBulkInferenceActionFilter.set(actionFilter);
329334

330335
var meterRegistry = services.telemetryProvider().getMeterRegistry();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 28 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.inference.action.filter;
99

10-
import org.apache.lucene.util.SetOnce;
1110
import org.elasticsearch.ElasticsearchStatusException;
1211
import org.elasticsearch.ExceptionsHelper;
1312
import org.elasticsearch.ResourceNotFoundException;
@@ -112,20 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
112111
private final InferenceServiceRegistry inferenceServiceRegistry;
113112
private final ModelRegistry modelRegistry;
114113
private final XPackLicenseState licenseState;
114+
private final IndexingPressure indexingPressure;
115115
private volatile long batchSizeInBytes;
116116

117-
private final SetOnce<IndexingPressure> indexingPressure = new SetOnce<>();
118-
119117
public ShardBulkInferenceActionFilter(
120118
ClusterService clusterService,
121119
InferenceServiceRegistry inferenceServiceRegistry,
122120
ModelRegistry modelRegistry,
123-
XPackLicenseState licenseState
121+
XPackLicenseState licenseState,
122+
IndexingPressure indexingPressure
124123
) {
125124
this.clusterService = clusterService;
126125
this.inferenceServiceRegistry = inferenceServiceRegistry;
127126
this.modelRegistry = modelRegistry;
128127
this.licenseState = licenseState;
128+
this.indexingPressure = indexingPressure;
129129
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
130130
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
131131
}
@@ -134,10 +134,6 @@ private void setBatchSize(ByteSizeValue newBatchSize) {
134134
batchSizeInBytes = newBatchSize.getBytes();
135135
}
136136

137-
public void setIndexingPressure(IndexingPressure indexingPressure) {
138-
this.indexingPressure.set(indexingPressure);
139-
}
140-
141137
@Override
142138
public String actionName() {
143139
return TransportShardBulkAction.ACTION_NAME;
@@ -156,14 +152,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
156152
var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
157153
if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
158154
// Maintain coordinating indexing pressure from inference until the indexing operations are complete
159-
CoordinatingIndexingPressureWrapper coordinatingWrapper = startCoordinatingOperations();
155+
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
160156
Runnable onInferenceCompletion = () -> chain.proceed(
161157
task,
162158
action,
163159
request,
164-
ActionListener.releaseAfter(listener, coordinatingWrapper)
160+
ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
165161
);
166-
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingWrapper);
162+
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
167163
return;
168164
}
169165
}
@@ -174,22 +170,13 @@ private void processBulkShardRequest(
174170
Map<String, InferenceFieldMetadata> fieldInferenceMap,
175171
BulkShardRequest bulkShardRequest,
176172
Runnable onCompletion,
177-
CoordinatingIndexingPressureWrapper coordinatingWrapper
173+
IndexingPressure.Coordinating coordinatingIndexingPressure
178174
) {
179175
final ProjectMetadata project = clusterService.state().getMetadata().getProject();
180176
var index = project.index(bulkShardRequest.index());
181177
boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
182-
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingWrapper).run();
183-
}
184-
185-
private CoordinatingIndexingPressureWrapper startCoordinatingOperations() {
186-
IndexingPressure.Coordinating coordinating = null;
187-
IndexingPressure localIndexingPressure = indexingPressure.get();
188-
if (localIndexingPressure != null) {
189-
coordinating = localIndexingPressure.createCoordinatingOperation(false);
190-
}
191-
192-
return new CoordinatingIndexingPressureWrapper(coordinating);
178+
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
179+
.run();
193180
}
194181

195182
private record InferenceProvider(InferenceService service, Model model) {}
@@ -259,21 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
259246
private final BulkShardRequest bulkShardRequest;
260247
private final Runnable onCompletion;
261248
private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
262-
private final CoordinatingIndexingPressureWrapper coordinatingWrapper;
249+
private final IndexingPressure.Coordinating coordinatingIndexingPressure;
263250

264251
private AsyncBulkShardInferenceAction(
265252
boolean useLegacyFormat,
266253
Map<String, InferenceFieldMetadata> fieldInferenceMap,
267254
BulkShardRequest bulkShardRequest,
268255
Runnable onCompletion,
269-
CoordinatingIndexingPressureWrapper coordinatingWrapper
256+
IndexingPressure.Coordinating coordinatingIndexingPressure
270257
) {
271258
this.useLegacyFormat = useLegacyFormat;
272259
this.fieldInferenceMap = fieldInferenceMap;
273260
this.bulkShardRequest = bulkShardRequest;
274261
this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
275262
this.onCompletion = onCompletion;
276-
this.coordinatingWrapper = coordinatingWrapper;
263+
this.coordinatingIndexingPressure = coordinatingIndexingPressure;
277264
}
278265

279266
@Override
@@ -612,8 +599,7 @@ private void setIndexingPressureIncremented() {
612599

613600
private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
614601
boolean success = true;
615-
IndexingPressure.Coordinating coordinatingIndexingPressure = coordinatingWrapper.coordinating();
616-
if (coordinatingIndexingPressure != null && indexRequest.isIndexingPressureIncremented() == false) {
602+
if (indexRequest.isIndexingPressureIncremented() == false) {
617603
try {
618604
// Track operation count as one operation per document source update
619605
coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
@@ -724,23 +710,20 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
724710
}
725711
long modifiedSourceSize = indexRequest.source().ramBytesUsed();
726712

727-
IndexingPressure.Coordinating coordinatingIndexingPressure = coordinatingWrapper.coordinating();
728-
if (coordinatingIndexingPressure != null) {
729-
// Add the indexing pressure from the source modifications.
730-
// Don't increment operation count because we count one source update as one operation, and we already accounted for those
731-
// in addFieldInferenceRequests.
732-
try {
733-
coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
734-
} catch (EsRejectedExecutionException e) {
735-
indexRequest.source(originalSource, indexRequest.getContentType());
736-
item.abort(
737-
item.index(),
738-
new InferenceException(
739-
"Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
740-
e
741-
)
742-
);
743-
}
713+
// Add the indexing pressure from the source modifications.
714+
// Don't increment operation count because we count one source update as one operation, and we already accounted for those
715+
// in addFieldInferenceRequests.
716+
try {
717+
coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
718+
} catch (EsRejectedExecutionException e) {
719+
indexRequest.source(originalSource, indexRequest.getContentType());
720+
item.abort(
721+
item.index(),
722+
new InferenceException(
723+
"Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
724+
e
725+
)
726+
);
744727
}
745728
}
746729
}
@@ -791,13 +774,4 @@ public Iterator<Chunk> chunksAsByteReference(XContent xcontent) {
791774
return Collections.emptyIterator();
792775
}
793776
}
794-
795-
private record CoordinatingIndexingPressureWrapper(@Nullable IndexingPressure.Coordinating coordinating) implements Releasable {
796-
@Override
797-
public void close() {
798-
if (coordinating != null) {
799-
coordinating.close();
800-
}
801-
}
802-
}
803777
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@
115115

116116
public class ShardBulkInferenceActionFilterTests extends ESTestCase {
117117
private static final Object EXPLICIT_NULL = new Object();
118+
private static final IndexingPressure NOOP_INDEXING_PRESSURE = new NoopIndexingPressure();
118119

119120
private final boolean useLegacyFormat;
120121
private ThreadPool threadPool;
@@ -140,7 +141,7 @@ public void tearDownThreadPool() throws Exception {
140141

141142
@SuppressWarnings({ "unchecked", "rawtypes" })
142143
public void testFilterNoop() throws Exception {
143-
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
144+
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
144145
CountDownLatch chainExecuted = new CountDownLatch(1);
145146
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
146147
try {
@@ -166,7 +167,7 @@ public void testFilterNoop() throws Exception {
166167
@SuppressWarnings({ "unchecked", "rawtypes" })
167168
public void testLicenseInvalidForInference() throws InterruptedException {
168169
StaticModel model = StaticModel.createRandomInstance();
169-
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
170+
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false);
170171
CountDownLatch chainExecuted = new CountDownLatch(1);
171172
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
172173
try {
@@ -207,6 +208,7 @@ public void testInferenceNotFound() throws Exception {
207208
ShardBulkInferenceActionFilter filter = createFilter(
208209
threadPool,
209210
Map.of(model.getInferenceEntityId(), model),
211+
NOOP_INDEXING_PRESSURE,
210212
useLegacyFormat,
211213
true
212214
);
@@ -253,6 +255,7 @@ public void testItemFailures() throws Exception {
253255
ShardBulkInferenceActionFilter filter = createFilter(
254256
threadPool,
255257
Map.of(model.getInferenceEntityId(), model),
258+
NOOP_INDEXING_PRESSURE,
256259
useLegacyFormat,
257260
true
258261
);
@@ -323,6 +326,7 @@ public void testExplicitNull() throws Exception {
323326
ShardBulkInferenceActionFilter filter = createFilter(
324327
threadPool,
325328
Map.of(model.getInferenceEntityId(), model),
329+
NOOP_INDEXING_PRESSURE,
326330
useLegacyFormat,
327331
true
328332
);
@@ -393,6 +397,7 @@ public void testHandleEmptyInput() throws Exception {
393397
ShardBulkInferenceActionFilter filter = createFilter(
394398
threadPool,
395399
Map.of(model.getInferenceEntityId(), model),
400+
NOOP_INDEXING_PRESSURE,
396401
useLegacyFormat,
397402
true
398403
);
@@ -465,7 +470,7 @@ public void testManyRandomDocs() throws Exception {
465470
modifiedRequests[id] = res[1];
466471
}
467472

468-
ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
473+
ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
469474
CountDownLatch chainExecuted = new CountDownLatch(1);
470475
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
471476
try {
@@ -504,10 +509,10 @@ public void testIndexingPressure() throws Exception {
504509
final ShardBulkInferenceActionFilter filter = createFilter(
505510
threadPool,
506511
Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel),
512+
indexingPressure,
507513
useLegacyFormat,
508514
true
509515
);
510-
filter.setIndexingPressure(indexingPressure);
511516

512517
XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value");
513518
XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", "another test value");
@@ -621,10 +626,10 @@ public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Excep
621626
final ShardBulkInferenceActionFilter filter = createFilter(
622627
threadPool,
623628
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
629+
indexingPressure,
624630
useLegacyFormat,
625631
true
626632
);
627-
filter.setIndexingPressure(indexingPressure);
628633

629634
XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
630635

@@ -703,10 +708,10 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except
703708
final ShardBulkInferenceActionFilter filter = createFilter(
704709
threadPool,
705710
Map.of(sparseModel.getInferenceEntityId(), sparseModel),
711+
indexingPressure,
706712
useLegacyFormat,
707713
true
708714
);
709-
filter.setIndexingPressure(indexingPressure);
710715

711716
CountDownLatch chainExecuted = new CountDownLatch(1);
712717
ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
@@ -775,6 +780,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except
775780
private static ShardBulkInferenceActionFilter createFilter(
776781
ThreadPool threadPool,
777782
Map<String, StaticModel> modelMap,
783+
IndexingPressure indexingPressure,
778784
boolean useLegacyFormat,
779785
boolean isLicenseValidForInference
780786
) {
@@ -852,7 +858,8 @@ private static ShardBulkInferenceActionFilter createFilter(
852858
createClusterService(useLegacyFormat),
853859
inferenceServiceRegistry,
854860
modelRegistry,
855-
licenseState
861+
licenseState,
862+
indexingPressure
856863
);
857864
}
858865

@@ -1035,11 +1042,11 @@ boolean hasResult(String text) {
10351042
private static class InstrumentedIndexingPressure extends IndexingPressure {
10361043
private Coordinating coordinating = null;
10371044

1038-
InstrumentedIndexingPressure(Settings settings) {
1045+
private InstrumentedIndexingPressure(Settings settings) {
10391046
super(settings);
10401047
}
10411048

1042-
public Coordinating getCoordinating() {
1049+
private Coordinating getCoordinating() {
10431050
return coordinating;
10441051
}
10451052

@@ -1049,4 +1056,24 @@ public Coordinating createCoordinatingOperation(boolean forceExecution) {
10491056
return coordinating;
10501057
}
10511058
}
1059+
1060+
private static class NoopIndexingPressure extends IndexingPressure {
1061+
private NoopIndexingPressure() {
1062+
super(Settings.EMPTY);
1063+
}
1064+
1065+
@Override
1066+
public Coordinating createCoordinatingOperation(boolean forceExecution) {
1067+
return new NoopCoordinating(forceExecution);
1068+
}
1069+
1070+
private class NoopCoordinating extends Coordinating {
1071+
private NoopCoordinating(boolean forceExecution) {
1072+
super(forceExecution);
1073+
}
1074+
1075+
@Override
1076+
public void increment(int operations, long bytes) {}
1077+
}
1078+
}
10521079
}

Comments (0)

0 commit comments