Commit c85718d

Address review comments and copy the source in place when possible
1 parent a12c9f1 commit c85718d

File tree

2 files changed: +128 −130 lines


x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 126 additions & 42 deletions
@@ -25,7 +25,9 @@
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.bytes.CompositeBytesReference;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
@@ -52,6 +54,7 @@
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xcontent.XContentType;
@@ -469,8 +472,8 @@ private void recordRequestCountMetrics(Model model, int incrementBy, Throwable t
      * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
      * for the specified {@code item}.
      *
-     * @param item The bulk request item to process.
-     * @param itemIndex The position of the item within the original bulk request.
+     * @param item        The bulk request item to process.
+     * @param itemIndex   The position of the item within the original bulk request.
      * @param requestsMap A map storing inference requests, where each key is an inference ID,
      *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
      * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
@@ -671,27 +674,137 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 );
                 inferenceFieldsMap.put(fieldName, result);
             }
-
-            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
                     SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue());
                 }
-                indexRequest.source(newDocMap, indexRequest.getContentType());
+                XContentBuilder builder = XContentFactory.contentBuilder(indexRequest.getContentType());
+                builder.map(newDocMap);
+                var newSource = BytesReference.bytes(builder);
+                if (incrementIndexingPressure(item, indexRequest, newSource.length())) {
+                    indexRequest.source(newSource, indexRequest.getContentType());
+                }
+            } else {
+                updateSourceWithInferenceFields(item, indexRequest, inferenceFieldsMap);
+            }
+        }
+
+        /**
+         * Updates the {@link IndexRequest}'s source to include additional inference fields.
+         * <p>
+         * If the original source uses an array-backed {@link BytesReference}, this method attempts an in-place update,
+         * reusing the existing array where possible and appending additional bytes only if needed.
+         * <p>
+         * If the original source is not array-backed, the entire source is replaced with the new source that includes
+         * the inference fields. In this case, the full size of the new source is accounted for in indexing pressure.
+         * <p>
+         * Note: We do not subtract the indexing pressure of the original source since its bytes may be pooled and not
+         * reclaimable by the garbage collector during the request lifecycle.
+         *
+         * @param item               The {@link BulkItemRequest} being processed.
+         * @param indexRequest       The {@link IndexRequest} whose source will be updated.
+         * @param inferenceFieldsMap A map of additional fields to append to the source.
+         * @throws IOException if building the new source fails.
+         */
+        private void updateSourceWithInferenceFields(
+            BulkItemRequest item,
+            IndexRequest indexRequest,
+            Map<String, Object> inferenceFieldsMap
+        ) throws IOException {
+            var originalSource = indexRequest.source();
+            final BytesReference newSource;
+
+            // Build a new source by appending the inference fields to the existing source.
+            try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
+                appendSourceAndInferenceMetadata(builder, originalSource, indexRequest.getContentType(), inferenceFieldsMap);
+                newSource = BytesReference.bytes(builder);
+            }
+
+            // Calculate the additional size to account for in indexing pressure.
+            final int additionalSize = originalSource.hasArray() ? newSource.length() - originalSource.length() : newSource.length();
+
+            // If we exceed the indexing pressure limit, do not proceed with the update.
+            if (incrementIndexingPressure(item, indexRequest, additionalSize) == false) {
+                return;
+            }
+
+            // Apply the updated source to the index request.
+            if (originalSource.hasArray()) {
+                // If the original source is backed by an array, perform in-place update:
+                // - Copy as much of the new source as fits into the original array.
+                System.arraycopy(
+                    newSource.array(),
+                    newSource.arrayOffset(),
+                    originalSource.array(),
+                    originalSource.arrayOffset(),
+                    originalSource.length()
+                );
+
+                int remainingSize = newSource.length() - originalSource.length();
+                if (remainingSize > 0) {
+                    // If there are additional bytes, append them as a new BytesArray segment.
+                    byte[] remainingBytes = new byte[remainingSize];
+                    System.arraycopy(
+                        newSource.array(),
+                        newSource.arrayOffset() + originalSource.length(),
+                        remainingBytes,
+                        0,
+                        remainingSize
+                    );
+                    indexRequest.source(
+                        CompositeBytesReference.of(originalSource, new BytesArray(remainingBytes)),
+                        indexRequest.getContentType()
+                    );
+                } else {
+                    // No additional bytes; just adjust the slice length.
+                    indexRequest.source(originalSource.slice(0, newSource.length()));
+                }
             } else {
-                try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
-                    appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap);
-                    indexRequest.source(builder);
+                // If the original source is not array-backed, replace it entirely.
+                indexRequest.source(newSource, indexRequest.getContentType());
+            }
+        }
+
+        /**
+         * Appends the original source and the new inference metadata field directly to the provided
+         * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
+         */
+        private void appendSourceAndInferenceMetadata(
+            XContentBuilder builder,
+            BytesReference source,
+            XContentType xContentType,
+            Map<String, Object> inferenceFieldsMap
+        ) throws IOException {
+            builder.startObject();
+
+            // append the original source
+            try (
+                XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)
+            ) {
+                // skip start object
+                parser.nextToken();
+                while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                    builder.copyCurrentStructure(parser);
                 }
             }
-            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
 
-            // Add the indexing pressure from the source modifications.
+            // add the inference metadata field
+            builder.field(InferenceMetadataFieldsMapper.NAME);
+            try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
+                builder.copyCurrentStructure(parser);
+            }
+
+            builder.endObject();
+        }
+
+        private boolean incrementIndexingPressure(BulkItemRequest item, IndexRequest indexRequest, int inc) {
             try {
-                coordinatingIndexingPressure.increment(1, modifiedSourceSize - originalSource.ramBytesUsed());
+                if (inc > 0) {
+                    coordinatingIndexingPressure.increment(1, inc);
+                }
+                return true;
             } catch (EsRejectedExecutionException e) {
-                indexRequest.source(originalSource, indexRequest.getContentType());
                 inferenceStats.bulkRejection().incrementBy(1);
                 item.abort(
                     item.index(),
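
Aside: the in-place update introduced in the hunk above reuses the original source buffer for the leading bytes of the rebuilt document and stitches on a tail segment only for the overflow. Below is a minimal, dependency-free sketch of that idea; plain byte arrays stand in for Elasticsearch's BytesReference, BytesArray, and CompositeBytesReference types, and the class and method names are made up for illustration.

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Illustrative sketch only: mirrors the copy-in-place strategy of updateSourceWithInferenceFields
// using plain arrays instead of Elasticsearch's pooled BytesReference types.
final class CopyInPlaceSketch {

    // Returns the segments that together represent the rebuilt source.
    static byte[][] copyInPlace(byte[] original, byte[] rebuilt) {
        // Overwrite the original buffer with as much of the rebuilt source as fits.
        System.arraycopy(rebuilt, 0, original, 0, Math.min(original.length, rebuilt.length));
        int remaining = rebuilt.length - original.length;
        if (remaining > 0) {
            // The source grew: allocate a tail holding only the extra bytes, analogous to
            // CompositeBytesReference.of(originalSource, new BytesArray(remainingBytes)) above.
            byte[] tail = Arrays.copyOfRange(rebuilt, original.length, rebuilt.length);
            return new byte[][] { original, tail };
        }
        // The source did not grow: a prefix of the original buffer is enough
        // (the real code takes a slice rather than copying).
        return new byte[][] { Arrays.copyOf(original, rebuilt.length) };
    }

    public static void main(String[] args) {
        byte[] original = "{\"field\":\"bar\"}".getBytes(StandardCharsets.UTF_8);
        byte[] rebuilt = "{\"field\":\"bar\",\"extra\":{}}".getBytes(StandardCharsets.UTF_8);
        System.out.println(copyInPlace(original, rebuilt).length); // 2: reused original + overflow tail
    }
}

Either way, the original allocation stays live for the duration of the request, which is why the change charges indexing pressure only for the growth instead of swapping the source for a fresh copy.
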
@@ -702,40 +815,11 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                         e
                     )
                 );
+                return false;
             }
         }
     }
 
-    /**
-     * Appends the original source and the new inference metadata field directly to the provided
-     * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
-     */
-    private static void appendSourceAndInferenceMetadata(
-        XContentBuilder builder,
-        BytesReference source,
-        XContentType xContentType,
-        Map<String, Object> inferenceFieldsMap
-    ) throws IOException {
-        builder.startObject();
-
-        // append the original source
-        try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) {
-            // skip start object
-            parser.nextToken();
-            while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
-                builder.copyCurrentStructure(parser);
-            }
-        }
-
-        // add the inference metadata field
-        builder.field(InferenceMetadataFieldsMapper.NAME);
-        try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
-            builder.copyCurrentStructure(parser);
-        }
-
-        builder.endObject();
-    }
-
     static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
         if (docWriteRequest instanceof IndexRequest indexRequest) {
             return indexRequest;
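
As a quick illustration of the accounting rule spelled out in the javadoc of updateSourceWithInferenceFields: when the original source is array-backed and reused in place, only the growth counts as new memory; when it is not array-backed, the whole rebuilt source is a new allocation and is charged in full. A hedged one-method sketch, with a helper name invented for this example:

// Hypothetical helper mirroring the `additionalSize` computation in the diff above.
static int additionalPressureBytes(boolean originalIsArrayBacked, int originalLength, int newLength) {
    // Array-backed: the original bytes are reused, so only the delta is new memory.
    // Otherwise: the full rebuilt source is newly allocated and charged.
    return originalIsArrayBacked ? newLength - originalLength : newLength;
}

For example, a 1,024-byte array-backed source rebuilt to 1,433 bytes charges 409 additional bytes of coordinating pressure, while a non-array-backed source of the same sizes charges the full 1,433 bytes. incrementIndexingPressure then skips the increment entirely when the computed delta is not positive, and aborts the item with a rejection when the limit is exceeded.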

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 2 additions & 88 deletions
@@ -131,9 +131,7 @@ public ShardBulkInferenceActionFilterTests(boolean useLegacyFormat) {
 
     @ParametersFactory
     public static Iterable<Object[]> parameters() throws Exception {
-        List<Object[]> lst = new ArrayList<>();
-        lst.add(new Object[] { true });
-        return lst;
+        return List.of(new Boolean[] { true }, new Boolean[] { false });
     }
 
     @Before
@@ -616,10 +614,7 @@ public void testIndexingPressure() throws Exception {
 
         IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
         assertThat(coordinatingIndexingPressure, notNullValue());
-        verify(coordinatingIndexingPressure, times(6)).increment(eq(1), longThat(l -> l > 0));
-        if (useLegacyFormat == false) {
-            verify(coordinatingIndexingPressure).increment(1, longThat(l -> l > bytesUsed(doc1UpdateSource)));
-        }
+        verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(1), longThat(l -> l > 0));
 
         // Verify that the only times that increment is called are the times verified above
         verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(anyInt(), anyLong());
@@ -668,87 +663,6 @@ public void testIndexingPressure() throws Exception {
         verify(coordinatingIndexingPressure).close();
     }
 
-    @SuppressWarnings("unchecked")
-    public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception {
-        final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock());
-        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
-            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
-        );
-        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
-        final ShardBulkInferenceActionFilter filter = createFilter(
-            threadPool,
-            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
-            indexingPressure,
-            useLegacyFormat,
-            true,
-            inferenceStats
-        );
-
-        XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
-
-        CountDownLatch chainExecuted = new CountDownLatch(1);
-        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
-            try {
-                assertNull(request.getInferenceFieldMap());
-                assertThat(request.items().length, equalTo(3));
-
-                assertNull(request.items()[0].getPrimaryResponse());
-                assertNull(request.items()[2].getPrimaryResponse());
-
-                BulkItemRequest doc1Request = request.items()[1];
-                BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
-                assertNotNull(doc1Response);
-                assertTrue(doc1Response.isFailed());
-                BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
-                assertThat(
-                    doc1Failure.getCause().getMessage(),
-                    containsString("Unable to insert inference results into document [doc_1] due to memory pressure.")
-                );
-                assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
-                assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
-                verify(inferenceStats.bulkRejection()).incrementBy(1);
-
-                IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
-                assertThat(doc1IndexRequest, notNullValue());
-                assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
-
-                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
-                assertThat(coordinatingIndexingPressure, notNullValue());
-                verify(coordinatingIndexingPressure).increment(eq(1), longThat(l -> l > bytesUsed(doc1Source)));
-                verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());
-
-                // Verify that the coordinating indexing pressure is maintained through downstream action filters
-                verify(coordinatingIndexingPressure, never()).close();
-
-                // Call the listener once the request is successfully processed, like is done in the production code path
-                listener.onResponse(null);
-            } finally {
-                chainExecuted.countDown();
-            }
-        };
-        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
-        Task task = mock(Task.class);
-
-        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
-            "sparse_field",
-            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
-        );
-
-        BulkItemRequest[] items = new BulkItemRequest[3];
-        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
-        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
-        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
-
-        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
-        request.setInferenceFieldMap(inferenceFieldMap);
-        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
-        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
-
-        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
-        assertThat(coordinatingIndexingPressure, notNullValue());
-        verify(coordinatingIndexingPressure).close();
-    }
-
     @SuppressWarnings("unchecked")
     public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
         final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
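
The simplified assertion in testIndexingPressure relies on Mockito's times and argument matchers to check that every pressure increment accounts for exactly one operation and a positive byte count. Below is a self-contained sketch of the same verification pattern, assuming mockito-core is on the classpath; the Counter interface is a hypothetical stand-in for IndexingPressure.Coordinating, and the counts (6 for the legacy format, 7 otherwise) come straight from the updated test above.

import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.longThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

// Hypothetical stand-in for IndexingPressure.Coordinating.
interface Counter {
    void increment(int ops, long bytes);
}

class VerificationSketch {
    static void check(boolean useLegacyFormat) {
        Counter counter = mock(Counter.class);

        // Simulate the expected number of pressure increments.
        int expectedCalls = useLegacyFormat ? 6 : 7;
        for (int i = 0; i < expectedCalls; i++) {
            counter.increment(1, 42L);
        }

        // Each call must account for exactly one operation and a strictly positive byte count.
        verify(counter, times(expectedCalls)).increment(eq(1), longThat(l -> l > 0));
    }
}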

0 commit comments
