5 changes: 5 additions & 0 deletions docs/changelog/124313.yaml
@@ -0,0 +1,5 @@
+pr: 124313
+summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+area: Search
+type: enhancement
+issues: []
@@ -20,6 +20,7 @@
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@
import java.util.Map;
import java.util.Set;

+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public void setup() throws Exception {

@Override
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+return Settings.builder()
+.put(otherSettings)
+.put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+.put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+.build();
}

@Override
@@ -142,6 +142,7 @@
import java.util.function.Supplier;

import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

public class InferencePlugin extends Plugin
@@ -445,6 +446,7 @@ public List<Setting<?>> getSettings() {
settings.addAll(Truncator.getSettingsDefinitions());
settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
settings.add(SKIP_VALIDATE_AND_START);
+settings.add(INDICES_INFERENCE_BATCH_SIZE);
settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());

return settings;

Large diffs are not rendered by default.
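The collapsed file above is where the new batching logic lives. The static imports added elsewhere in this PR show that the INDICES_INFERENCE_BATCH_SIZE setting is declared in ShardBulkInferenceActionFilter; a byte-size cluster setting of that kind is typically declared along the following lines. This is only a sketch: the setting key, default value, and property flags are assumptions, not the PR's actual values.

// Sketch only: the key "indices.inference.batch_size", the 1mb default, and the
// property flags are assumptions; the real declaration is in the collapsed diff above.
public static final Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
    "indices.inference.batch_size",
    ByteSizeValue.ofMb(1),
    Setting.Property.NodeScope,
    Setting.Property.OperatorDynamic
);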

@@ -267,37 +267,38 @@ private static List<Chunk> parseChunksArrayLegacy(XContentParser parser, ParserC
/**
* Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
*/
-public static List<Chunk> toSemanticTextFieldChunks(
-String input,
-int offsetAdjustment,
-ChunkedInference results,
-XContentType contentType,
-boolean useLegacyFormat
-) throws IOException {
+public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
+throws IOException {
List<Chunk> chunks = new ArrayList<>();
Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
while (it.hasNext()) {
-chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
+chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
}
return chunks;
}

-public static Chunk toSemanticTextFieldChunk(
-String input,
-int offsetAdjustment,
-ChunkedInference.Chunk chunk,
-boolean useLegacyFormat
-) {
+/**
+ * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
+ */
+public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
String text = null;
-int startOffset = -1;
-int endOffset = -1;
-if (useLegacyFormat) {
-text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
-} else {
-startOffset = chunk.textOffset().start() + offsetAdjustment;
-endOffset = chunk.textOffset().end() + offsetAdjustment;
+int startOffset = chunk.textOffset().start() + offsetAdjustment;
+int endOffset = chunk.textOffset().end() + offsetAdjustment;
+return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
}

+public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
+throws IOException {
+List<Chunk> chunks = new ArrayList<>();
+Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
+while (it.hasNext()) {
+chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
+}
+return chunks;
+}

-return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
+public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+return new Chunk(text, -1, -1, chunk.bytesReference());
}
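A minimal usage sketch of the split API shown above: callers now pick the legacy or the offset-based conversion explicitly instead of passing a useLegacyFormat flag. The variables input, offsetAdjustment, results, contentType, and useLegacyFormat stand in for values the caller already has.

// Sketch: choosing between the legacy and the offset-based chunk conversion.
List<SemanticTextField.Chunk> chunks = useLegacyFormat
    ? SemanticTextField.toSemanticTextFieldChunksLegacy(input, results, contentType)
    : SemanticTextField.toSemanticTextFieldChunks(offsetAdjustment, results, contentType);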
@@ -27,7 +27,9 @@
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.IndexVersion;
@@ -65,12 +67,13 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
-import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -115,7 +118,7 @@ public void tearDownThreadPool() throws Exception {

@SuppressWarnings({ "unchecked", "rawtypes" })
public void testFilterNoop() throws Exception {
-ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
+ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -141,7 +144,7 @@
@SuppressWarnings({ "unchecked", "rawtypes" })
public void testLicenseInvalidForInference() throws InterruptedException {
StaticModel model = StaticModel.createRandomInstance();
-ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
+ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -182,7 +185,6 @@ public void testInferenceNotFound() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
-randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -229,7 +231,6 @@ public void testItemFailures() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
-randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -300,7 +301,6 @@ public void testExplicitNull() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
-randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -371,7 +371,6 @@ public void testHandleEmptyInput() throws Exception {
ShardBulkInferenceActionFilter filter = createFilter(
threadPool,
Map.of(model.getInferenceEntityId(), model),
-randomIntBetween(1, 10),
useLegacyFormat,
true
);
@@ -444,13 +443,7 @@ public void testManyRandomDocs() throws Exception {
modifiedRequests[id] = res[1];
}

-ShardBulkInferenceActionFilter filter = createFilter(
-threadPool,
-inferenceModelMap,
-randomIntBetween(10, 30),
-useLegacyFormat,
-true
-);
+ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
CountDownLatch chainExecuted = new CountDownLatch(1);
ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
try {
@@ -484,7 +477,6 @@ public void testManyRandomDocs() throws Exception {
private static ShardBulkInferenceActionFilter createFilter(
ThreadPool threadPool,
Map<String, StaticModel> modelMap,
-int batchSize,
boolean useLegacyFormat,
boolean isLicenseValidForInference
) {
@@ -551,26 +543,28 @@ private static ShardBulkInferenceActionFilter createFilter(
createClusterService(useLegacyFormat),
inferenceServiceRegistry,
modelRegistry,
-licenseState,
-batchSize
+licenseState
);
}

private static ClusterService createClusterService(boolean useLegacyFormat) {
IndexMetadata indexMetadata = mock(IndexMetadata.class);
-var settings = Settings.builder()
+var indexSettings = Settings.builder()
.put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
.put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
.build();
-when(indexMetadata.getSettings()).thenReturn(settings);
+when(indexMetadata.getSettings()).thenReturn(indexSettings);

Metadata metadata = mock(Metadata.class);
when(metadata.index(any(String.class))).thenReturn(indexMetadata);

ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenReturn(clusterState);

+long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
+when(clusterService.getSettings()).thenReturn(settings);
+when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
return clusterService;
}

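The mocked ClusterService above registers the setting with a randomized byte-size value. For reference, a minimal sketch of resolving that value from the mock; how the production filter actually reads and uses the setting is in the diff that is not rendered above.

// Sketch: resolving the randomized batch size wired up by createClusterService above.
ByteSizeValue batchSize = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings());
long batchSizeInBytes = batchSize.getBytes();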
@@ -581,7 +575,8 @@ private static BulkItemRequest[] randomBulkItemRequest(
) throws IOException {
Map<String, Object> docMap = new LinkedHashMap<>();
Map<String, Object> expectedDocMap = new LinkedHashMap<>();
-XContentType requestContentType = randomFrom(XContentType.values());
+// force JSON to avoid double/float conversions
+XContentType requestContentType = XContentType.JSON;

Map<String, Object> inferenceMetadataFields = new HashMap<>();
for (var entry : fieldInferenceMap.values()) {
@@ -41,6 +41,7 @@

import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;

@@ -259,7 +260,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
while (inputsIt.hasNext() && chunkIt.hasNext()) {
String input = inputsIt.next();
var chunk = chunkIt.next();
-chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
+chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));

// When using the inference metadata fields format, all the input values are concatenated so that the
// chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment