
Commit 45383c8

jimczi and kderusso authored
Add support for highlighting the new format of the semantic text field (#119604)
This change adapts the semantic highlighter to work with the new format introduced in #119183.

Co-authored-by: Kathleen DeRusso <[email protected]>
1 parent 3310ca9 commit 45383c8

File tree

8 files changed: +377, -4366 lines

server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/DefaultHighlighter.java

Lines changed: 2 additions & 2 deletions
```diff
@@ -227,15 +227,15 @@ protected static BreakIterator getBreakIterator(SearchHighlightContext.Field fie
         }
     }
 
-    protected static String convertFieldValue(MappedFieldType type, Object value) {
+    public static String convertFieldValue(MappedFieldType type, Object value) {
         if (value instanceof BytesRef) {
             return type.valueForDisplay(value).toString();
         } else {
             return value.toString();
         }
     }
 
-    protected static String mergeFieldValues(List<Object> fieldValues, char valuesSeparator) {
+    public static String mergeFieldValues(List<Object> fieldValues, char valuesSeparator) {
         // postings highlighter accepts all values in a single string, as offsets etc. need to match with content
         // loaded from stored fields, we merge all values using a proper separator
         String rawValue = Strings.collectionToDelimitedString(fieldValues, String.valueOf(valuesSeparator));
```
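These two helpers are promoted from protected to public so that highlighters outside the server module, like the semantic highlighter below, can reuse them. A minimal sketch of such reuse follows; `FieldValueJoiner` is a hypothetical caller, and only the two `DefaultHighlighter` calls and `MULTIVAL_SEP_CHAR` are the APIs touched by this commit:

```java
// Hypothetical caller reusing the now-public DefaultHighlighter helpers.
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.search.fetch.subphase.highlight.DefaultHighlighter;

import java.util.List;

import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;

final class FieldValueJoiner {
    /**
     * Converts each stored value into its display form and joins the values with
     * the unified highlighter's multi-value separator, so offsets computed over
     * the merged string line up with the indexed content.
     */
    static String join(MappedFieldType fieldType, List<Object> storedValues) {
        List<Object> converted = storedValues.stream()
            .<Object>map(v -> DefaultHighlighter.convertFieldValue(fieldType, v))
            .toList();
        return DefaultHighlighter.mergeFieldValues(converted, MULTIVAL_SEP_CHAR);
    }
}
```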

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java

Lines changed: 100 additions & 28 deletions
```diff
@@ -26,20 +26,31 @@
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
 import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.search.fetch.FetchSubPhase;
+import org.elasticsearch.search.fetch.subphase.highlight.DefaultHighlighter;
 import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
 import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightUtils;
 import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
 import org.elasticsearch.search.vectors.VectorData;
 import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
+import org.elasticsearch.xpack.inference.mapper.OffsetSourceField;
+import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
+import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SemanticTextFieldType;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.function.Function;
+
+import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;
 
 /**
  * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.
```
```diff
@@ -49,20 +60,19 @@
 public class SemanticTextHighlighter implements Highlighter {
     public static final String NAME = "semantic";
 
-    private record OffsetAndScore(int offset, float score) {}
+    private record OffsetAndScore(int index, OffsetSourceFieldMapper.OffsetSource offset, float score) {}
 
     @Override
     public boolean canHighlight(MappedFieldType fieldType) {
-        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
-            // TODO: Implement highlighting when using inference metadata fields
-            return semanticTextFieldType.useLegacyFormat();
-        }
-        return false;
+        return fieldType instanceof SemanticTextFieldType;
     }
 
     @Override
     public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
-        SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;
+        if (canHighlight(fieldContext.fieldType) == false) {
+            return null;
+        }
+        SemanticTextFieldType fieldType = (SemanticTextFieldType) fieldContext.fieldType;
         if (fieldType.getEmbeddingsField() == null) {
             // nothing indexed yet
             return null;
```
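The `OffsetAndScore` record now distinguishes a match's rank in document order (`index`) from the chunk's location in its source field (`offset`). A self-contained sketch of that shape, where `OffsetSource` is a hypothetical stand-in for `OffsetSourceFieldMapper.OffsetSource`:

```java
// Hypothetical mirror of the per-match data after this change: the match's rank
// (index), the chunk's location in its source field (offset, null in the legacy
// format), and the query score for the chunk.
public class OffsetAndScoreExample {
    record OffsetSource(String field, int start, int end) {}

    record OffsetAndScore(int index, OffsetSource offset, float score) {}

    public static void main(String[] args) {
        String body = "Elasticsearch highlights semantic text.";
        var match = new OffsetAndScore(0, new OffsetSource("body", 0, 13), 1.5f);
        // In the new format the snippet is a substring of the source field content:
        System.out.println(body.substring(match.offset().start(), match.offset().end()));
    }
}
```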
```diff
@@ -105,28 +115,36 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOExc
         int size = Math.min(chunks.size(), numberOfFragments);
         if (fieldContext.field.fieldOptions().scoreOrdered() == false) {
             chunks = chunks.subList(0, size);
-            chunks.sort(Comparator.comparingInt(c -> c.offset));
+            chunks.sort(Comparator.comparingInt(c -> c.index));
         }
         Text[] snippets = new Text[size];
-        List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
-            fieldType.getChunksField().fullPath(),
-            fieldContext.hitContext.source().source()
-        );
+        final Function<OffsetAndScore, String> offsetToContent;
+        if (fieldType.useLegacyFormat()) {
+            List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
+                fieldType.getChunksField().fullPath(),
+                fieldContext.hitContext.source().source()
+            );
+            offsetToContent = entry -> getContentFromLegacyNestedSources(fieldType.name(), entry, nestedSources);
+        } else {
+            Map<String, String> fieldToContent = new HashMap<>();
+            offsetToContent = entry -> {
+                String content = fieldToContent.computeIfAbsent(entry.offset().field(), key -> {
+                    try {
+                        return extractFieldContent(
+                            fieldContext.context.getSearchExecutionContext(),
+                            fieldContext.hitContext,
+                            entry.offset.field()
+                        );
+                    } catch (IOException e) {
+                        throw new UncheckedIOException("Error extracting field content from field " + entry.offset.field(), e);
+                    }
+                });
+                return content.substring(entry.offset().start(), entry.offset().end());
+            };
+        }
         for (int i = 0; i < size; i++) {
             var chunk = chunks.get(i);
-            if (nestedSources.size() <= chunk.offset) {
-                throw new IllegalStateException(
-                    String.format(
-                        Locale.ROOT,
-                        "Invalid content detected for field [%s]: the chunks size is [%d], "
-                            + "but a reference to offset [%d] was found in the result.",
-                        fieldType.name(),
-                        nestedSources.size(),
-                        chunk.offset
-                    )
-                );
-            }
-            String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);
+            String content = offsetToContent.apply(chunk);
             if (content == null) {
                 throw new IllegalStateException(
                     String.format(
```
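In the non-legacy branch above, each source field's content is loaded lazily and at most once per hit, with checked `IOException`s rethrown as `UncheckedIOException` so the loader fits a `Function`. A standalone sketch of just that pattern, with `ContentLoader` as a hypothetical stand-in for `extractFieldContent`:

```java
// Lazy per-field caching: content is loaded at most once per field name, and
// checked IOExceptions are wrapped so the result can be used as a Function.
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

final class LazyFieldContent {
    interface ContentLoader {
        String load(String field) throws IOException;
    }

    static Function<String, String> cached(ContentLoader loader) {
        Map<String, String> cache = new HashMap<>();
        return field -> cache.computeIfAbsent(field, key -> {
            try {
                return loader.load(key);
            } catch (IOException e) {
                throw new UncheckedIOException("Error extracting field content from field " + key, e);
            }
        });
    }
}
```

The highlighter then takes `content.substring(start, end)` per chunk, so a document with many chunks in the same field pays the load cost only once.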
```diff
@@ -143,10 +161,43 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOExc
         return new HighlightField(fieldContext.fieldName, snippets);
     }
 
+    private String extractFieldContent(SearchExecutionContext searchContext, FetchSubPhase.HitContext hitContext, String sourceField)
+        throws IOException {
+        var sourceFieldType = searchContext.getMappingLookup().getFieldType(sourceField);
+        if (sourceFieldType == null) {
+            return null;
+        }
+
+        var values = HighlightUtils.loadFieldValues(sourceFieldType, searchContext, hitContext)
+            .stream()
+            .<Object>map((s) -> DefaultHighlighter.convertFieldValue(sourceFieldType, s))
+            .toList();
+        if (values.size() == 0) {
+            return null;
+        }
+        return DefaultHighlighter.mergeFieldValues(values, MULTIVAL_SEP_CHAR);
+    }
+
+    private String getContentFromLegacyNestedSources(String fieldName, OffsetAndScore cand, List<Map<?, ?>> nestedSources) {
+        if (nestedSources.size() <= cand.index) {
+            throw new IllegalStateException(
+                String.format(
+                    Locale.ROOT,
+                    "Invalid content detected for field [%s]: the chunks size is [%d], "
+                        + "but a reference to offset [%d] was found in the result.",
+                    fieldName,
+                    nestedSources.size(),
+                    cand.index
+                )
+            );
+        }
+        return (String) nestedSources.get(cand.index).get(SemanticTextField.CHUNKED_TEXT_FIELD);
+    }
+
     private List<OffsetAndScore> extractOffsetAndScores(
         SearchExecutionContext context,
         LeafReader reader,
-        SemanticTextFieldMapper.SemanticTextFieldType fieldType,
+        SemanticTextFieldType fieldType,
         int docId,
         List<Query> leafQueries
     ) throws IOException {
```
```diff
@@ -164,10 +215,31 @@ private List<OffsetAndScore> extractOffsetAndScores(
         } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
             return List.of();
         }
+
+        OffsetSourceField.OffsetSourceLoader offsetReader = null;
+        if (fieldType.useLegacyFormat() == false) {
+            var terms = reader.terms(fieldType.getOffsetsField().fullPath());
+            if (terms == null) {
+                // The field is empty
+                return List.of();
+            }
+            offsetReader = OffsetSourceField.loader(terms);
+        }
+
         List<OffsetAndScore> results = new ArrayList<>();
-        int offset = 0;
+        int index = 0;
         while (scorer.docID() < docId) {
-            results.add(new OffsetAndScore(offset++, scorer.score()));
+            if (offsetReader != null) {
+                var offset = offsetReader.advanceTo(scorer.docID());
+                if (offset == null) {
+                    throw new IllegalStateException(
+                        "Cannot highlight field [" + fieldType.name() + "], missing offsets for doc [" + docId + "]"
+                    );
+                }
+                results.add(new OffsetAndScore(index++, offset, scorer.score()));
+            } else {
+                results.add(new OffsetAndScore(index++, null, scorer.score()));
+            }
             if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
                 break;
             }
```
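Both branches of that loop walk the scorer over the chunk documents that precede the parent document, collecting one score per chunk in index order. A reduced sketch of the legacy branch, assuming the caller has already positioned the scorer on the first matching chunk (`ChunkScore` and the surrounding class are hypothetical):

```java
// Collect one (index, score) entry per chunk document below the parent docId.
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

final class ChunkScores {
    record ChunkScore(int index, float score) {}

    static List<ChunkScore> collect(Scorer scorer, int parentDocId) throws IOException {
        List<ChunkScore> results = new ArrayList<>();
        int index = 0;
        // Nested chunk docs are indexed before their parent, so every matching
        // doc with an id below parentDocId is a chunk of this hit.
        while (scorer.docID() < parentDocId) {
            results.add(new ChunkScore(index++, scorer.score()));
            if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
                break;
            }
        }
        return results;
    }
}
```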

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java

Lines changed: 25 additions & 17 deletions
```diff
@@ -7,6 +7,8 @@
 
 package org.elasticsearch.xpack.inference.highlight;
 
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
 import org.apache.lucene.analysis.standard.StandardAnalyzer;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexWriterConfig;
```
```diff
@@ -51,7 +53,6 @@
 import org.elasticsearch.xpack.core.ml.search.WeightedToken;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
-import org.junit.Before;
 import org.mockito.Mockito;
 
 import java.io.IOException;
```
```diff
@@ -71,31 +72,35 @@ public class SemanticTextHighlighterTests extends MapperServiceTestCase {
     private static final String SEMANTIC_FIELD_E5 = "body-e5";
     private static final String SEMANTIC_FIELD_ELSER = "body-elser";
 
-    private Map<String, Object> queries;
+    private final boolean useLegacyFormat;
+    private final Map<String, Object> queries;
 
-    @Override
-    protected Collection<? extends Plugin> getPlugins() {
-        return List.of(new InferencePlugin(Settings.EMPTY));
+    public SemanticTextHighlighterTests(boolean useLegacyFormat) throws IOException {
+        this.useLegacyFormat = useLegacyFormat;
+        var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json"));
+        this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2();
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() throws Exception {
+        return List.of(new Object[] { true }, new Object[] { false });
     }
 
     @Override
-    @Before
-    public void setUp() throws Exception {
-        super.setUp();
-        var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json"));
-        this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2();
+    protected Collection<? extends Plugin> getPlugins() {
+        return List.of(new InferencePlugin(Settings.EMPTY));
     }
 
     @SuppressWarnings("unchecked")
     public void testDenseVector() throws Exception {
-        var mapperService = createDefaultMapperService();
+        var mapperService = createDefaultMapperService(useLegacyFormat);
         Map<String, Object> queryMap = (Map<String, Object>) queries.get("dense_vector_1");
         float[] vector = readDenseVector(queryMap.get("embeddings"));
         var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5);
         KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null);
         NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max);
         var shardRequest = createShardSearchRequest(nestedQueryBuilder);
-        var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON);
+        var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON);
 
         String[] expectedScorePassages = ((List<String>) queryMap.get("expected_by_score")).toArray(String[]::new);
         for (int i = 0; i < expectedScorePassages.length; i++) {
```
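The suite is now parameterized over the two formats via randomizedtesting's `@ParametersFactory`: the factory returns one `Object[]` of constructor arguments per run of the class. A minimal, self-contained sketch of that wiring (`ExampleFormatTests` and its test body are hypothetical):

```java
// Runs the whole suite once per constructor argument: legacy format, then new format.
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.tests.util.LuceneTestCase;

import java.util.List;

public class ExampleFormatTests extends LuceneTestCase {
    private final boolean useLegacyFormat;

    public ExampleFormatTests(boolean useLegacyFormat) {
        this.useLegacyFormat = useLegacyFormat;
    }

    @ParametersFactory
    public static Iterable<Object[]> parameters() {
        // One Object[] per instantiation of the test class.
        return List.of(new Object[] { true }, new Object[] { false });
    }

    public void testPicksFixtureForFormat() {
        // Mirrors readSampleDoc below: the flag selects the per-format fixture.
        String fileName = useLegacyFormat ? "sample-doc-legacy.json.gz" : "sample-doc.json.gz";
        assertTrue(fileName.endsWith(".json.gz"));
    }
}
```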
```diff
@@ -124,7 +129,7 @@ public void testDenseVector() throws Exception {
 
     @SuppressWarnings("unchecked")
     public void testSparseVector() throws Exception {
-        var mapperService = createDefaultMapperService();
+        var mapperService = createDefaultMapperService(useLegacyFormat);
         Map<String, Object> queryMap = (Map<String, Object>) queries.get("sparse_vector_1");
         List<WeightedToken> tokens = readSparseVector(queryMap.get("embeddings"));
         var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_ELSER);
```
```diff
@@ -138,7 +143,7 @@ public void testSparseVector() throws Exception {
         );
         NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), sparseQuery, ScoreMode.Max);
         var shardRequest = createShardSearchRequest(nestedQueryBuilder);
-        var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON);
+        var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON);
 
         String[] expectedScorePassages = ((List<String>) queryMap.get("expected_by_score")).toArray(String[]::new);
         for (int i = 0; i < expectedScorePassages.length; i++) {
```
```diff
@@ -165,9 +170,11 @@ public void testSparseVector() throws Exception {
         );
     }
 
-    private MapperService createDefaultMapperService() throws IOException {
+    private MapperService createDefaultMapperService(boolean useLegacyFormat) throws IOException {
         var mappings = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("mappings.json"));
-        var settings = Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), true).build();
+        var settings = Settings.builder()
+            .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
+            .build();
         return createMapperService(settings, mappings.utf8ToString());
     }
 
```
```diff
@@ -282,7 +289,8 @@ private ShardSearchRequest createShardSearchRequest(QueryBuilder queryBuilder) {
         return new ShardSearchRequest(OriginalIndices.NONE, request, new ShardId("index", "index", 0), 0, 1, AliasFilter.EMPTY, 1, 0, null);
     }
 
-    private BytesReference readSampleDoc(String fileName) throws IOException {
+    private BytesReference readSampleDoc(boolean useLegacyFormat) throws IOException {
+        String fileName = useLegacyFormat ? "sample-doc-legacy.json.gz" : "sample-doc.json.gz";
         try (var in = new GZIPInputStream(SemanticTextHighlighterTests.class.getResourceAsStream(fileName))) {
             return new BytesArray(new BytesRef(in.readAllBytes()));
         }
```
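For reference, a standalone sketch of that fixture-loading pattern, assuming the same two gzipped resources sit next to a hypothetical `SampleDocs` class on the test classpath:

```java
// Pick the per-format fixture and decompress it fully into memory.
import java.io.IOException;
import java.io.InputStream;
import java.util.Objects;
import java.util.zip.GZIPInputStream;

final class SampleDocs {
    static byte[] readSampleDoc(boolean useLegacyFormat) throws IOException {
        String fileName = useLegacyFormat ? "sample-doc-legacy.json.gz" : "sample-doc.json.gz";
        InputStream resource = Objects.requireNonNull(
            SampleDocs.class.getResourceAsStream(fileName),
            "missing test resource: " + fileName
        );
        // The sample docs are small test fixtures, so reading them whole is fine.
        try (InputStream in = new GZIPInputStream(resource)) {
            return in.readAllBytes();
        }
    }
}
```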
