From 8b0ce533e1c45c87d90db64af7f2687ef6bcc243 Mon Sep 17 00:00:00 2001
From: Jan Kuipers
Date: Wed, 5 Mar 2025 16:52:10 +0100
Subject: [PATCH] Prevent duplicate source parsing in
 ShardBulkInferenceActionFilter.

---
 .../filter/ShardBulkInferenceActionFilter.java     | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index 7783e8599279d..f30d4f4f58bd4 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
@@ -180,7 +180,8 @@ private record FieldInferenceResponse(
     private record FieldInferenceResponseAccumulator(
         int id,
         Map<String, List<FieldInferenceResponse>> responses,
-        List<Exception> failures
+        List<Exception> failures,
+        Map<String, Object> source
     ) {
         void addOrUpdateResponse(FieldInferenceResponse response) {
             synchronized (this) {
@@ -376,17 +377,17 @@ private void onFinish() {
                 .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
         }
 
-        private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
+        private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id, Map<String, Object> source) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
-                acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
+                acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>(), source);
                 inferenceResults.set(id, acc);
             }
             return acc;
         }
 
         private void addInferenceResponseFailure(int id, Exception failure) {
-            var acc = ensureResponseAccumulatorSlot(id);
+            var acc = ensureResponseAccumulatorSlot(id, null);
             acc.addFailure(failure);
         }
 
@@ -404,7 +405,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
             }
 
             final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
-            var newDocMap = indexRequest.sourceAsMap();
+            var newDocMap = response.source();
             Map<String, Object> inferenceFieldsMap = new HashMap<>();
             for (var entry : response.responses.entrySet()) {
                 var fieldName = entry.getKey();
@@ -542,7 +543,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                      * This ensures that the field is treated as intentionally cleared,
                      * preventing any unintended carryover of prior inference results.
                      */
-                    var slot = ensureResponseAccumulatorSlot(itemIndex);
+                    var slot = ensureResponseAccumulatorSlot(itemIndex, docMap);
                     slot.addOrUpdateResponse(
                         new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
                     );
@@ -563,7 +564,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 }
                 continue;
             }
-            ensureResponseAccumulatorSlot(itemIndex);
+            ensureResponseAccumulatorSlot(itemIndex, docMap);
            final List<String> values;
            try {
                values = SemanticTextUtils.nodeStringValues(field, valueObj);