
Commit 361b51d

Optimize memory usage in ShardBulkInferenceActionFilter (#124313)
This refactor improves memory efficiency by processing inference requests in batches, capped by a maximum input length. Changes include:

- A new dynamic operator setting to control the maximum batch size in bytes.
- Dropping input data from inference responses when the legacy semantic text format isn't used, saving memory.
- Clearing inference results after each bulk item to free up memory sooner.

This is a step toward enabling circuit breakers to better handle memory usage when dealing with large inputs.
1 parent 35ecbf6 commit 361b51d
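The INDICES_INFERENCE_BATCH_SIZE setting referenced throughout the diffs below is the new byte-capped batch control. As a minimal sketch, here is how such a dynamic operator ByteSizeValue setting is typically declared inside a class like ShardBulkInferenceActionFilter; the setting key and default shown are assumptions for illustration, since this page does not render them:

import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.unit.ByteSizeValue;

// Hedged sketch: only the field name INDICES_INFERENCE_BATCH_SIZE and its
// ByteSizeValue type are visible in the diffs below; the key and default
// here are assumptions, not this commit's actual values.
public static final Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
    "indices.inference.batch_size",   // assumed setting key
    ByteSizeValue.ofMb(1),            // assumed default cap on batched input bytes
    Setting.Property.NodeScope,
    Setting.Property.OperatorDynamic  // "dynamic operator setting" per the commit message
);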

File tree: 7 files changed, +313 -245 lines


docs/changelog/124313.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 124313
+summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+area: Search
+type: enhancement
+issues: []

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 8 additions & 1 deletion
@@ -20,6 +20,7 @@
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@
 import java.util.Map;
 import java.util.Set;
 
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public void setup() throws Exception {
 
     @Override
     protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        return Settings.builder()
+            .put(otherSettings)
+            .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+            .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+            .build();
     }
 
     @Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
@@ -142,6 +142,7 @@
 import java.util.function.Supplier;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
 
 public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public List<Setting<?>> getSettings() {
         settings.addAll(Truncator.getSettingsDefinitions());
         settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
         settings.add(SKIP_VALIDATE_AND_START);
+        settings.add(INDICES_INFERENCE_BATCH_SIZE);
         settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
 
         return settings;
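Registering the setting with the plugin makes it visible to ClusterSettings. The test changes further down mock both ClusterService.getSettings() and getClusterSettings(), which suggests the filter reads an initial value and then subscribes for runtime updates. A hedged sketch of that common consumer pattern; the volatile field and constructor shape are assumptions, not this commit's code:

// Hedged sketch of the usual update-consumer pattern for a dynamic setting.
private volatile long batchSizeInBytes; // assumed field, not shown in this commit

ShardBulkInferenceActionFilter(ClusterService clusterService) {
    // Read the initial value from node settings, then track dynamic updates.
    this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
    clusterService.getClusterSettings()
        .addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, value -> this.batchSizeInBytes = value.getBytes());
}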

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 257 additions & 200 deletions
Large diffs are not rendered by default.
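Since the main filter diff is not rendered here, the following is an illustrative sketch of the batching idea the commit message describes: group per-item inference requests into batches whose accumulated input size stays under the configured byte cap. The types and method below are assumptions based on the commit message, not the actual implementation.

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Assumed request shape: an inference endpoint id plus the raw input text.
record InferenceRequest(String inferenceId, String input) {}

final class ByteCappedBatcher {
    private final long maxBatchSizeInBytes; // backed by the dynamic INDICES_INFERENCE_BATCH_SIZE setting

    ByteCappedBatcher(long maxBatchSizeInBytes) {
        this.maxBatchSizeInBytes = maxBatchSizeInBytes;
    }

    // Dispatches requests in batches whose summed input size stays under the cap.
    // A batch always holds at least one request, so an oversized input still runs.
    void forEachBatch(List<InferenceRequest> requests, Consumer<List<InferenceRequest>> execute) {
        List<InferenceRequest> batch = new ArrayList<>();
        long batchBytes = 0;
        for (InferenceRequest request : requests) {
            long size = request.input().getBytes(StandardCharsets.UTF_8).length;
            if (batch.isEmpty() == false && batchBytes + size > maxBatchSizeInBytes) {
                execute.accept(batch); // dispatch the full batch, then start a new one
                batch = new ArrayList<>();
                batchBytes = 0;
            }
            batch.add(request);
            batchBytes += size;
        }
        if (batch.isEmpty() == false) {
            execute.accept(batch);
        }
    }
}

Note that the integration test above randomizes the cap down to zero bytes; under a scheme like this sketch, that puts every request in its own single-element batch, exercising the smallest batches the filter can produce.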

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 23 additions & 22 deletions
@@ -267,37 +267,38 @@ private static List<Chunk> parseChunksArrayLegacy(XContentParser parser, ParserC
     /**
      * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
      */
-    public static List<Chunk> toSemanticTextFieldChunks(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference results,
-        XContentType contentType,
-        boolean useLegacyFormat
-    ) throws IOException {
+    public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
+        throws IOException {
         List<Chunk> chunks = new ArrayList<>();
         Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
         while (it.hasNext()) {
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
+            chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
         }
         return chunks;
     }
 
-    public static Chunk toSemanticTextFieldChunk(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference.Chunk chunk,
-        boolean useLegacyFormat
-    ) {
+    /**
+     * Converts the provided {@link ChunkedInference.Chunk} into a {@link Chunk}.
+     */
+    public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
         String text = null;
-        int startOffset = -1;
-        int endOffset = -1;
-        if (useLegacyFormat) {
-            text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
-        } else {
-            startOffset = chunk.textOffset().start() + offsetAdjustment;
-            endOffset = chunk.textOffset().end() + offsetAdjustment;
-        }
+        int startOffset = chunk.textOffset().start() + offsetAdjustment;
+        int endOffset = chunk.textOffset().end() + offsetAdjustment;
+        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
+    }
+
+    public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
+        throws IOException {
+        List<Chunk> chunks = new ArrayList<>();
+        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
+        while (it.hasNext()) {
+            chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
+        }
+        return chunks;
+    }
 
-        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
+    public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+        var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+        return new Chunk(text, -1, -1, chunk.bytesReference());
     }
 }
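Splitting out toSemanticTextFieldChunksLegacy means the original input string only has to be retained for legacy-format indices; the new-format path carries offsets into the stored source instead of copied text. Callers now choose the variant per index format, as the test change further down shows; condensed here for reference (the surrounding variables are assumed from that test's context):

// Condensed from the SemanticTextFieldTests change later in this commit.
var chunk = useLegacyFormat
    ? toSemanticTextFieldChunkLegacy(input, inferenceChunk)       // legacy: copies the chunk text, offsets are -1
    : toSemanticTextFieldChunk(offsetAdjustment, inferenceChunk); // new: offsets only, text stays null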

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 16 additions & 21 deletions
@@ -28,7 +28,9 @@
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
@@ -66,12 +68,13 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
-import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -118,7 +121,7 @@ public void tearDownThreadPool() throws Exception {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -144,7 +147,7 @@ public void testFilterNoop() throws Exception {
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -185,7 +188,6 @@ public void testInferenceNotFound() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -232,7 +234,6 @@ public void testItemFailures() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -303,7 +304,6 @@ public void testExplicitNull() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -374,7 +374,6 @@ public void testHandleEmptyInput() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -447,13 +446,7 @@ public void testManyRandomDocs() throws Exception {
             modifiedRequests[id] = res[1];
         }
 
-        ShardBulkInferenceActionFilter filter = createFilter(
-            threadPool,
-            inferenceModelMap,
-            randomIntBetween(10, 30),
-            useLegacyFormat,
-            true
-        );
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -487,7 +480,6 @@ public void testManyRandomDocs() throws Exception {
     private static ShardBulkInferenceActionFilter createFilter(
         ThreadPool threadPool,
         Map<String, StaticModel> modelMap,
-        int batchSize,
         boolean useLegacyFormat,
         boolean isLicenseValidForInference
     ) {
@@ -554,18 +546,17 @@ private static ShardBulkInferenceActionFilter createFilter(
             createClusterService(useLegacyFormat),
             inferenceServiceRegistry,
             modelRegistry,
-            licenseState,
-            batchSize
+            licenseState
         );
     }
 
     private static ClusterService createClusterService(boolean useLegacyFormat) {
         IndexMetadata indexMetadata = mock(IndexMetadata.class);
-        var settings = Settings.builder()
+        var indexSettings = Settings.builder()
             .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
             .build();
-        when(indexMetadata.getSettings()).thenReturn(settings);
+        when(indexMetadata.getSettings()).thenReturn(indexSettings);
 
         ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build());
         when(project.index(anyString())).thenReturn(indexMetadata);
@@ -576,7 +567,10 @@ private static ClusterService createClusterService(boolean useLegacyFormat) {
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
         ClusterService clusterService = mock(ClusterService.class);
         when(clusterService.state()).thenReturn(clusterState);
-
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
+        when(clusterService.getSettings()).thenReturn(settings);
+        when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
         return clusterService;
     }
 
@@ -587,7 +581,8 @@ private static BulkItemRequest[] randomBulkItemRequest(
     ) throws IOException {
         Map<String, Object> docMap = new LinkedHashMap<>();
         Map<String, Object> expectedDocMap = new LinkedHashMap<>();
-        XContentType requestContentType = randomFrom(XContentType.values());
+        // force JSON to avoid double/float conversions
+        XContentType requestContentType = XContentType.JSON;
 
         Map<String, Object> inferenceMetadataFields = new HashMap<>();
         for (var entry : fieldInferenceMap.values()) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 2 additions & 1 deletion
@@ -41,6 +41,7 @@
 
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -274,7 +275,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
         while (inputsIt.hasNext() && chunkIt.hasNext()) {
             String input = inputsIt.next();
             var chunk = chunkIt.next();
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
+            chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));
 
             // When using the inference metadata fields format, all the input values are concatenated so that the
             // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
