3939import org .elasticsearch .inference .UnparsedModel ;
4040import org .elasticsearch .rest .RestStatus ;
4141import org .elasticsearch .tasks .Task ;
42+ import org .elasticsearch .xcontent .XContent ;
4243import org .elasticsearch .xpack .core .inference .results .ChunkedInferenceError ;
4344import org .elasticsearch .xpack .inference .mapper .SemanticTextField ;
4445import org .elasticsearch .xpack .inference .mapper .SemanticTextFieldMapper ;
5051import java .util .Collections ;
5152import java .util .Comparator ;
5253import java .util .HashMap ;
54+ import java .util .Iterator ;
5355import java .util .LinkedHashMap ;
5456import java .util .List ;
5557import java .util .Map ;
6769 */
6870public class ShardBulkInferenceActionFilter implements MappedActionFilter {
6971 protected static final int DEFAULT_BATCH_SIZE = 512 ;
72+ private static final Object EXPLICIT_NULL = new Object ();
73+ private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference ();
7074
7175 private final ClusterService clusterService ;
7276 private final InferenceServiceRegistry inferenceServiceRegistry ;
@@ -393,11 +397,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
393397 for (var entry : response .responses .entrySet ()) {
394398 var fieldName = entry .getKey ();
395399 var responses = entry .getValue ();
396- var model = responses .get (0 ).model ();
400+ Model model = null ;
401+
402+ InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap .get (fieldName );
403+ if (inferenceFieldMetadata == null ) {
404+ throw new IllegalStateException ("No inference field metadata for field [" + fieldName + "]" );
405+ }
406+
397407 // ensure that the order in the original field is consistent in case of multiple inputs
398408 Collections .sort (responses , Comparator .comparingInt (FieldInferenceResponse ::inputOrder ));
399409 Map <String , List <SemanticTextField .Chunk >> chunkMap = new LinkedHashMap <>();
400410 for (var resp : responses ) {
411+ // Get the first non-null model from the response list
412+ if (model == null ) {
413+ model = resp .model ;
414+ }
415+
401416 var lst = chunkMap .computeIfAbsent (resp .sourceField , k -> new ArrayList <>());
402417 lst .addAll (
403418 SemanticTextField .toSemanticTextFieldChunks (
@@ -409,21 +424,26 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
409424 )
410425 );
411426 }
427+
412428 List <String > inputs = responses .stream ()
413429 .filter (r -> r .sourceField ().equals (fieldName ))
414430 .map (r -> r .input )
415431 .collect (Collectors .toList ());
432+
433+ // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
434+ // merge in the field's existing model settings on the data node.
416435 var result = new SemanticTextField (
417436 useLegacyFormat ,
418437 fieldName ,
419438 useLegacyFormat ? inputs : null ,
420439 new SemanticTextField .InferenceResult (
421- model . getInferenceEntityId (),
422- new SemanticTextField .ModelSettings (model ),
440+ inferenceFieldMetadata . getInferenceId (),
441+ model != null ? new SemanticTextField .ModelSettings (model ) : null ,
423442 chunkMap
424443 ),
425444 indexRequest .getContentType ()
426445 );
446+
427447 if (useLegacyFormat ) {
428448 SemanticTextUtils .insertValue (fieldName , newDocMap , result );
429449 } else {
@@ -490,7 +510,8 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
490510 } else {
491511 var inferenceMetadataFieldsValue = XContentMapValues .extractValue (
492512 InferenceMetadataFieldsMapper .NAME + "." + field ,
493- docMap
513+ docMap ,
514+ EXPLICIT_NULL
494515 );
495516 if (inferenceMetadataFieldsValue != null ) {
496517 // Inference has already been computed
@@ -500,9 +521,22 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
500521
501522 int order = 0 ;
502523 for (var sourceField : entry .getSourceFields ()) {
503- // TODO: Detect when the field is provided with an explicit null value
504- var valueObj = XContentMapValues .extractValue (sourceField , docMap );
505- if (valueObj == null ) {
524+ var valueObj = XContentMapValues .extractValue (sourceField , docMap , EXPLICIT_NULL );
525+ if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL ) {
526+ /**
527+ * It's an update request, and the source field is explicitly set to null,
528+ * so we need to propagate this information to the inference fields metadata
529+ * to overwrite any inference previously computed on the field.
530+ * This ensures that the field is treated as intentionally cleared,
531+ * preventing any unintended carryover of prior inference results.
532+ */
533+ var slot = ensureResponseAccumulatorSlot (itemIndex );
534+ slot .addOrUpdateResponse (
535+ new FieldInferenceResponse (field , sourceField , null , order ++, 0 , null , EMPTY_CHUNKED_INFERENCE )
536+ );
537+ continue ;
538+ }
539+ if (valueObj == null || valueObj == EXPLICIT_NULL ) {
506540 if (isUpdateRequest && useLegacyFormat ) {
507541 addInferenceResponseFailure (
508542 item .id (),
@@ -552,4 +586,11 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
552586 return null ;
553587 }
554588 }
589+
590+ private static class EmptyChunkedInference implements ChunkedInference {
591+ @ Override
592+ public Iterator <Chunk > chunksAsMatchedTextAndByteReference (XContent xcontent ) {
593+ return Collections .emptyIterator ();
594+ }
595+ }
555596}
0 commit comments