Skip to content

Commit 17e2721

Browse files
authored
Optimize memory usage in ShardBulkInferenceActionFilter (#124313) (#124863)
This refactor improves memory efficiency by processing inference requests in batches, capped by a max input length. Changes include:
- A new dynamic operator setting to control the maximum batch size in bytes.
- Dropping input data from inference responses when the legacy semantic text format isn't used, saving memory.
- Clearing inference results dynamically after each bulk item to free up memory sooner.
This is a step toward enabling circuit breakers to better handle memory usage when dealing with large inputs.
1 parent 4e76ffb commit 17e2721

File tree

7 files changed

+313
-245
lines changed

7 files changed

+313
-245
lines changed

docs/changelog/124313.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 124313
2+
summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
3+
area: Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.action.update.UpdateRequestBuilder;
2121
import org.elasticsearch.cluster.metadata.IndexMetadata;
2222
import org.elasticsearch.common.settings.Settings;
23+
import org.elasticsearch.common.unit.ByteSizeValue;
2324
import org.elasticsearch.index.IndexSettings;
2425
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
2526
import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@
4445
import java.util.Map;
4546
import java.util.Set;
4647

48+
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
4749
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
4850
import static org.hamcrest.Matchers.containsString;
4951
import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public void setup() throws Exception {
8587

8688
@Override
8789
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
88-
return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
90+
long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
91+
return Settings.builder()
92+
.put(otherSettings)
93+
.put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
94+
.put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
95+
.build();
8996
}
9097

9198
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142
import java.util.function.Supplier;
143143

144144
import static java.util.Collections.singletonList;
145+
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
145146
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
146147

147148
public class InferencePlugin extends Plugin
@@ -445,6 +446,7 @@ public List<Setting<?>> getSettings() {
445446
settings.addAll(Truncator.getSettingsDefinitions());
446447
settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
447448
settings.add(SKIP_VALIDATE_AND_START);
449+
settings.add(INDICES_INFERENCE_BATCH_SIZE);
448450
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
449451

450452
return settings;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 257 additions & 200 deletions
Large diffs are not rendered by default.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -267,37 +267,38 @@ private static List<Chunk> parseChunksArrayLegacy(XContentParser parser, ParserC
267267
/**
268268
* Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
269269
*/
270-
public static List<Chunk> toSemanticTextFieldChunks(
271-
String input,
272-
int offsetAdjustment,
273-
ChunkedInference results,
274-
XContentType contentType,
275-
boolean useLegacyFormat
276-
) throws IOException {
270+
public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
271+
throws IOException {
277272
List<Chunk> chunks = new ArrayList<>();
278273
Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
279274
while (it.hasNext()) {
280-
chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
275+
chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
281276
}
282277
return chunks;
283278
}
284279

285-
public static Chunk toSemanticTextFieldChunk(
286-
String input,
287-
int offsetAdjustment,
288-
ChunkedInference.Chunk chunk,
289-
boolean useLegacyFormat
290-
) {
280+
/**
281+
* Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
282+
*/
283+
public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
291284
String text = null;
292-
int startOffset = -1;
293-
int endOffset = -1;
294-
if (useLegacyFormat) {
295-
text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
296-
} else {
297-
startOffset = chunk.textOffset().start() + offsetAdjustment;
298-
endOffset = chunk.textOffset().end() + offsetAdjustment;
285+
int startOffset = chunk.textOffset().start() + offsetAdjustment;
286+
int endOffset = chunk.textOffset().end() + offsetAdjustment;
287+
return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
288+
}
289+
290+
public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
291+
throws IOException {
292+
List<Chunk> chunks = new ArrayList<>();
293+
Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
294+
while (it.hasNext()) {
295+
chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
299296
}
297+
return chunks;
298+
}
300299

301-
return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
300+
public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
301+
var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
302+
return new Chunk(text, -1, -1, chunk.bytesReference());
302303
}
303304
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import org.elasticsearch.cluster.metadata.Metadata;
2828
import org.elasticsearch.cluster.service.ClusterService;
2929
import org.elasticsearch.common.Strings;
30+
import org.elasticsearch.common.settings.ClusterSettings;
3031
import org.elasticsearch.common.settings.Settings;
32+
import org.elasticsearch.common.unit.ByteSizeValue;
3133
import org.elasticsearch.common.xcontent.XContentHelper;
3234
import org.elasticsearch.common.xcontent.support.XContentMapValues;
3335
import org.elasticsearch.index.IndexVersion;
@@ -65,12 +67,13 @@
6567
import java.util.List;
6668
import java.util.Map;
6769
import java.util.Optional;
70+
import java.util.Set;
6871
import java.util.concurrent.CountDownLatch;
6972
import java.util.concurrent.TimeUnit;
7073

7174
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
7275
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
73-
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
76+
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
7477
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
7578
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
7679
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -115,7 +118,7 @@ public void tearDownThreadPool() throws Exception {
115118

116119
@SuppressWarnings({ "unchecked", "rawtypes" })
117120
public void testFilterNoop() throws Exception {
118-
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
121+
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
119122
CountDownLatch chainExecuted = new CountDownLatch(1);
120123
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
121124
try {
@@ -141,7 +144,7 @@ public void testFilterNoop() throws Exception {
141144
@SuppressWarnings({ "unchecked", "rawtypes" })
142145
public void testLicenseInvalidForInference() throws InterruptedException {
143146
StaticModel model = StaticModel.createRandomInstance();
144-
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
147+
ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
145148
CountDownLatch chainExecuted = new CountDownLatch(1);
146149
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
147150
try {
@@ -182,7 +185,6 @@ public void testInferenceNotFound() throws Exception {
182185
ShardBulkInferenceActionFilter filter = createFilter(
183186
threadPool,
184187
Map.of(model.getInferenceEntityId(), model),
185-
randomIntBetween(1, 10),
186188
useLegacyFormat,
187189
true
188190
);
@@ -229,7 +231,6 @@ public void testItemFailures() throws Exception {
229231
ShardBulkInferenceActionFilter filter = createFilter(
230232
threadPool,
231233
Map.of(model.getInferenceEntityId(), model),
232-
randomIntBetween(1, 10),
233234
useLegacyFormat,
234235
true
235236
);
@@ -300,7 +301,6 @@ public void testExplicitNull() throws Exception {
300301
ShardBulkInferenceActionFilter filter = createFilter(
301302
threadPool,
302303
Map.of(model.getInferenceEntityId(), model),
303-
randomIntBetween(1, 10),
304304
useLegacyFormat,
305305
true
306306
);
@@ -371,7 +371,6 @@ public void testHandleEmptyInput() throws Exception {
371371
ShardBulkInferenceActionFilter filter = createFilter(
372372
threadPool,
373373
Map.of(model.getInferenceEntityId(), model),
374-
randomIntBetween(1, 10),
375374
useLegacyFormat,
376375
true
377376
);
@@ -444,13 +443,7 @@ public void testManyRandomDocs() throws Exception {
444443
modifiedRequests[id] = res[1];
445444
}
446445

447-
ShardBulkInferenceActionFilter filter = createFilter(
448-
threadPool,
449-
inferenceModelMap,
450-
randomIntBetween(10, 30),
451-
useLegacyFormat,
452-
true
453-
);
446+
ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
454447
CountDownLatch chainExecuted = new CountDownLatch(1);
455448
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
456449
try {
@@ -484,7 +477,6 @@ public void testManyRandomDocs() throws Exception {
484477
private static ShardBulkInferenceActionFilter createFilter(
485478
ThreadPool threadPool,
486479
Map<String, StaticModel> modelMap,
487-
int batchSize,
488480
boolean useLegacyFormat,
489481
boolean isLicenseValidForInference
490482
) {
@@ -551,26 +543,28 @@ private static ShardBulkInferenceActionFilter createFilter(
551543
createClusterService(useLegacyFormat),
552544
inferenceServiceRegistry,
553545
modelRegistry,
554-
licenseState,
555-
batchSize
546+
licenseState
556547
);
557548
}
558549

559550
private static ClusterService createClusterService(boolean useLegacyFormat) {
560551
IndexMetadata indexMetadata = mock(IndexMetadata.class);
561-
var settings = Settings.builder()
552+
var indexSettings = Settings.builder()
562553
.put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
563554
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
564555
.build();
565-
when(indexMetadata.getSettings()).thenReturn(settings);
556+
when(indexMetadata.getSettings()).thenReturn(indexSettings);
566557

567558
Metadata metadata = mock(Metadata.class);
568559
when(metadata.index(any(String.class))).thenReturn(indexMetadata);
569560

570561
ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
571562
ClusterService clusterService = mock(ClusterService.class);
572563
when(clusterService.state()).thenReturn(clusterState);
573-
564+
long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
565+
Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
566+
when(clusterService.getSettings()).thenReturn(settings);
567+
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
574568
return clusterService;
575569
}
576570

@@ -581,7 +575,8 @@ private static BulkItemRequest[] randomBulkItemRequest(
581575
) throws IOException {
582576
Map<String, Object> docMap = new LinkedHashMap<>();
583577
Map<String, Object> expectedDocMap = new LinkedHashMap<>();
584-
XContentType requestContentType = randomFrom(XContentType.values());
578+
// force JSON to avoid double/float conversions
579+
XContentType requestContentType = XContentType.JSON;
585580

586581
Map<String, Object> inferenceMetadataFields = new HashMap<>();
587582
for (var entry : fieldInferenceMap.values()) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
4343
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
44+
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
4445
import static org.hamcrest.Matchers.containsString;
4546
import static org.hamcrest.Matchers.equalTo;
4647

@@ -259,7 +260,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
259260
while (inputsIt.hasNext() && chunkIt.hasNext()) {
260261
String input = inputsIt.next();
261262
var chunk = chunkIt.next();
262-
chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
263+
chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));
263264

264265
// When using the inference metadata fields format, all the input values are concatenated so that the
265266
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment

0 commit comments

Comments (0)