Skip to content

Commit 8b0ce53

Browse files
committed
Prevent duplicate source parsing in ShardBulkInferenceActionFilter.
1 parent d266f58 commit 8b0ce53

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,8 @@ private record FieldInferenceResponse(
180180
private record FieldInferenceResponseAccumulator(
181181
int id,
182182
Map<String, List<FieldInferenceResponse>> responses,
183-
List<Exception> failures
183+
List<Exception> failures,
184+
Map<String, Object> source
184185
) {
185186
void addOrUpdateResponse(FieldInferenceResponse response) {
186187
synchronized (this) {
@@ -376,17 +377,17 @@ private void onFinish() {
376377
.chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
377378
}
378379

379-
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
380+
private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id, Map<String, Object> source) {
380381
FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
381382
if (acc == null) {
382-
acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>());
383+
acc = new FieldInferenceResponseAccumulator(id, new HashMap<>(), new ArrayList<>(), source);
383384
inferenceResults.set(id, acc);
384385
}
385386
return acc;
386387
}
387388

388389
private void addInferenceResponseFailure(int id, Exception failure) {
389-
var acc = ensureResponseAccumulatorSlot(id);
390+
var acc = ensureResponseAccumulatorSlot(id, null);
390391
acc.addFailure(failure);
391392
}
392393

@@ -404,7 +405,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
404405
}
405406

406407
final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
407-
var newDocMap = indexRequest.sourceAsMap();
408+
var newDocMap = response.source();
408409
Map<String, Object> inferenceFieldsMap = new HashMap<>();
409410
for (var entry : response.responses.entrySet()) {
410411
var fieldName = entry.getKey();
@@ -542,7 +543,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
542543
* This ensures that the field is treated as intentionally cleared,
543544
* preventing any unintended carryover of prior inference results.
544545
*/
545-
var slot = ensureResponseAccumulatorSlot(itemIndex);
546+
var slot = ensureResponseAccumulatorSlot(itemIndex, docMap);
546547
slot.addOrUpdateResponse(
547548
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
548549
);
@@ -563,7 +564,7 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
563564
}
564565
continue;
565566
}
566-
ensureResponseAccumulatorSlot(itemIndex);
567+
ensureResponseAccumulatorSlot(itemIndex, docMap);
567568
final List<String> values;
568569
try {
569570
values = SemanticTextUtils.nodeStringValues(field, valueObj);

0 commit comments

Comments (0)