Skip to content

Commit 7d0e861

Browse files
authored
Semantic text - Clear inference results on explicit nulls (#119463) (#119515)
Fix a bug where explicitly setting a semantic_text source field to null in an update request did not clear the previously computed inference results for that field. This bug only affects the new _inference_fields format.
1 parent 0ed520b commit 7d0e861

File tree

9 files changed

+461
-96
lines changed

9 files changed

+461
-96
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.features.FeatureSpecification;
1111
import org.elasticsearch.features.NodeFeature;
12+
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
1213
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
1314
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
1415
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
@@ -46,7 +47,8 @@ public Set<NodeFeature> getTestFeatures() {
4647
SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
4748
SEMANTIC_TEXT_HIGHLIGHTER,
4849
SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
49-
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED
50+
SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
51+
SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES
5052
);
5153
}
5254
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.elasticsearch.inference.UnparsedModel;
4040
import org.elasticsearch.rest.RestStatus;
4141
import org.elasticsearch.tasks.Task;
42+
import org.elasticsearch.xcontent.XContent;
4243
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
4344
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
4445
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
@@ -50,6 +51,7 @@
5051
import java.util.Collections;
5152
import java.util.Comparator;
5253
import java.util.HashMap;
54+
import java.util.Iterator;
5355
import java.util.LinkedHashMap;
5456
import java.util.List;
5557
import java.util.Map;
@@ -67,6 +69,8 @@
6769
*/
6870
public class ShardBulkInferenceActionFilter implements MappedActionFilter {
6971
protected static final int DEFAULT_BATCH_SIZE = 512;
72+
private static final Object EXPLICIT_NULL = new Object();
73+
private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();
7074

7175
private final ClusterService clusterService;
7276
private final InferenceServiceRegistry inferenceServiceRegistry;
@@ -393,11 +397,22 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
393397
for (var entry : response.responses.entrySet()) {
394398
var fieldName = entry.getKey();
395399
var responses = entry.getValue();
396-
var model = responses.get(0).model();
400+
Model model = null;
401+
402+
InferenceFieldMetadata inferenceFieldMetadata = fieldInferenceMap.get(fieldName);
403+
if (inferenceFieldMetadata == null) {
404+
throw new IllegalStateException("No inference field metadata for field [" + fieldName + "]");
405+
}
406+
397407
// ensure that the order in the original field is consistent in case of multiple inputs
398408
Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
399409
Map<String, List<SemanticTextField.Chunk>> chunkMap = new LinkedHashMap<>();
400410
for (var resp : responses) {
411+
// Get the first non-null model from the response list
412+
if (model == null) {
413+
model = resp.model;
414+
}
415+
401416
var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
402417
lst.addAll(
403418
SemanticTextField.toSemanticTextFieldChunks(
@@ -409,21 +424,26 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
409424
)
410425
);
411426
}
427+
412428
List<String> inputs = responses.stream()
413429
.filter(r -> r.sourceField().equals(fieldName))
414430
.map(r -> r.input)
415431
.collect(Collectors.toList());
432+
433+
// The model can be null if we are only processing update requests that clear inference results. This is ok because we will
434+
// merge in the field's existing model settings on the data node.
416435
var result = new SemanticTextField(
417436
useLegacyFormat,
418437
fieldName,
419438
useLegacyFormat ? inputs : null,
420439
new SemanticTextField.InferenceResult(
421-
model.getInferenceEntityId(),
422-
new SemanticTextField.ModelSettings(model),
440+
inferenceFieldMetadata.getInferenceId(),
441+
model != null ? new SemanticTextField.ModelSettings(model) : null,
423442
chunkMap
424443
),
425444
indexRequest.getContentType()
426445
);
446+
427447
if (useLegacyFormat) {
428448
SemanticTextUtils.insertValue(fieldName, newDocMap, result);
429449
} else {
@@ -490,7 +510,8 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
490510
} else {
491511
var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
492512
InferenceMetadataFieldsMapper.NAME + "." + field,
493-
docMap
513+
docMap,
514+
EXPLICIT_NULL
494515
);
495516
if (inferenceMetadataFieldsValue != null) {
496517
// Inference has already been computed
@@ -500,9 +521,22 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
500521

501522
int order = 0;
502523
for (var sourceField : entry.getSourceFields()) {
503-
// TODO: Detect when the field is provided with an explicit null value
504-
var valueObj = XContentMapValues.extractValue(sourceField, docMap);
505-
if (valueObj == null) {
524+
var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
525+
if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
526+
/**
527+
* It's an update request, and the source field is explicitly set to null,
528+
* so we need to propagate this information to the inference fields metadata
529+
* to overwrite any inference previously computed on the field.
530+
* This ensures that the field is treated as intentionally cleared,
531+
* preventing any unintended carryover of prior inference results.
532+
*/
533+
var slot = ensureResponseAccumulatorSlot(itemIndex);
534+
slot.addOrUpdateResponse(
535+
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
536+
);
537+
continue;
538+
}
539+
if (valueObj == null || valueObj == EXPLICIT_NULL) {
506540
if (isUpdateRequest && useLegacyFormat) {
507541
addInferenceResponseFailure(
508542
item.id(),
@@ -552,4 +586,11 @@ static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
552586
return null;
553587
}
554588
}
589+
590+
private static class EmptyChunkedInference implements ChunkedInference {
591+
@Override
592+
public Iterator<Chunk> chunksAsMatchedTextAndByteReference(XContent xcontent) {
593+
return Collections.emptyIterator();
594+
}
595+
}
555596
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsMapper.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.search.Query;
1313
import org.apache.lucene.search.join.BitSetProducer;
1414
import org.elasticsearch.common.xcontent.XContentParserUtils;
15+
import org.elasticsearch.features.NodeFeature;
1516
import org.elasticsearch.index.mapper.ContentPath;
1617
import org.elasticsearch.index.mapper.DocumentParserContext;
1718
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
@@ -38,6 +39,8 @@
3839
public class SemanticInferenceMetadataFieldsMapper extends InferenceMetadataFieldsMapper {
3940
private static final SemanticInferenceMetadataFieldsMapper INSTANCE = new SemanticInferenceMetadataFieldsMapper();
4041

42+
public static final NodeFeature EXPLICIT_NULL_FIXES = new NodeFeature("semantic_text.inference_metadata_fields.explicit_null_fixes");
43+
4144
public static final TypeParser PARSER = new FixedTypeParser(
4245
c -> InferenceMetadataFieldsMapper.isEnabled(c.getSettings()) ? INSTANCE : null
4346
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
338338

339339
static {
340340
SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD));
341-
SEMANTIC_TEXT_FIELD_PARSER.declareObject(
342-
constructorArg(),
343-
(p, c) -> INFERENCE_RESULT_PARSER.parse(p, c),
344-
new ParseField(INFERENCE_FIELD)
345-
);
341+
SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD));
346342

347343
INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
348-
INFERENCE_RESULT_PARSER.declareObject(
344+
INFERENCE_RESULT_PARSER.declareObjectOrNull(
349345
constructorArg(),
350346
(p, c) -> MODEL_SETTINGS_PARSER.parse(p, null),
347+
null,
351348
new ParseField(MODEL_SETTINGS_FIELD)
352349
);
353350
INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,17 @@ void parseCreateFieldFromContext(DocumentParserContext context, SemanticTextFiel
384384
mapper = this;
385385
}
386386

387+
if (mapper.fieldType().getModelSettings() == null) {
388+
for (var chunkList : field.inference().chunks().values()) {
389+
if (chunkList.isEmpty() == false) {
390+
throw new DocumentParsingException(
391+
xContentLocation,
392+
"[" + MODEL_SETTINGS_FIELD + "] must be set for field [" + fullFieldName + "] when chunks are provided"
393+
);
394+
}
395+
}
396+
}
397+
387398
var chunksField = mapper.fieldType().getChunksField();
388399
var embeddingsField = mapper.fieldType().getEmbeddingsField();
389400
var offsetsField = mapper.fieldType().getOffsetsField();
@@ -895,7 +906,7 @@ private static boolean canMergeModelSettings(
895906
if (Objects.equals(previous, current)) {
896907
return true;
897908
}
898-
if (previous == null) {
909+
if (previous == null || current == null) {
899910
return true;
900911
}
901912
conflicts.addConflict("model_settings", "");

0 commit comments

Comments
 (0)