Commit 534a96a

Add Highlighter for Semantic Text Fields
This PR introduces a new highlighter, `semantic`, tailored for semantic text fields. It extracts the most relevant fragments by scoring nested chunks using the original semantic query. In this initial version, the highlighter returns only the original chunks computed during ingestion. However, this is an implementation detail, and future enhancements could combine multiple chunks to generate the fragments.
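The selection logic can be pictured as follows (a minimal, self-contained Java sketch with hypothetical types; the committed implementation is in `SemanticTextHighlighter` below): each ingested chunk is re-scored with the semantic query, and the text of the top-scoring chunks becomes the fragments.

import java.util.Comparator;
import java.util.List;

// Hypothetical model of an ingested chunk: its position in the nested chunk
// list and the score the semantic query assigns to its embedding.
record ScoredChunk(int offset, float score, String text) {}

class FragmentSelector {
    // Keep the text of the top `numberOfFragments` chunks, best score first.
    static List<String> topFragments(List<ScoredChunk> chunks, int numberOfFragments) {
        return chunks.stream()
            .sorted(Comparator.comparingDouble(ScoredChunk::score).reversed())
            .limit(numberOfFragments)
            .map(ScoredChunk::text)
            .toList();
    }
}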
1 parent 2fe6b60 commit 534a96a

File tree: 9 files changed, +1026 −43 lines


docs/reference/mapping/types/semantic-text.asciidoc

Lines changed: 22 additions & 31 deletions
@@ -112,50 +112,41 @@ Trying to <<delete-inference-api,delete an {infer} endpoint>> that is used on a
 {infer-cap} endpoints have a limit on the amount of text they can process.
 To allow for large amounts of text to be used in semantic search, `semantic_text` automatically generates smaller passages if needed, called _chunks_.
 
-Each chunk will include the text subpassage and the corresponding embedding generated from it.
+Each chunk refers to a passage of the text and the corresponding embedding generated from it.
 When querying, the individual passages will be automatically searched for each document, and the most relevant passage will be used to compute a score.
 
 For more details on chunking and how to configure chunking settings, see <<infer-chunking-config, Configuring chunking>> in the Inference API documentation.
 
+Refer to <<semantic-search-semantic-text,this tutorial>> to learn more about
+semantic search using `semantic_text` and the `semantic` query.
 
 [discrete]
-[[semantic-text-structure]]
-==== `semantic_text` structure
+[[semantic-text-highlighting]]
+==== Extracting Relevant Fragments from Semantic Text
 
-Once a document is ingested, a `semantic_text` field will have the following structure:
+You can extract the most relevant fragments from a semantic text field by using the <<highlighting,highlight parameter>> in the <<search-search-api-request-body,Search API>>.
 
-[source,console-result]
+[source,console]
 ------------------------------------------------------------
-"inference_field": {
-  "text": "these are not the droids you're looking for", <1>
-  "inference": {
-    "inference_id": "my-elser-endpoint", <2>
-    "model_settings": { <3>
-      "task_type": "sparse_embedding"
+PUT test-index
+{
+    "query": {
+        "semantic": {
+            "field": "my_semantic_field"
+        }
     },
-    "chunks": [ <4>
-      {
-        "text": "these are not the droids you're looking for",
-        "embeddings": {
-          (...)
+    "highlight": {
+        "fields": {
+            "my_semantic_field": {
+                "type": "semantic",
+                "number_of_fragments": 2 <1>
+            }
         }
-      }
-    ]
-  }
+    }
 }
 ------------------------------------------------------------
-// TEST[skip:TBD]
-<1> The field will become an object structure to accommodate both the original
-text and the inference results.
-<2> The `inference_id` used to generate the embeddings.
-<3> Model settings, including the task type and dimensions/similarity if
-applicable.
-<4> Inference results will be grouped in chunks, each with its corresponding
-text and embeddings.
-
-Refer to <<semantic-search-semantic-text,this tutorial>> to learn more about
-semantic search using `semantic_text` and the `semantic` query.
-
+// TEST[skip:Requires inference endpoint]
+<1> Specifies the maximum number of fragments to return.
 
 [discrete]
 [[custom-indexing]]
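A hedged end-to-end sketch of the same request from plain Java follows. Assumptions: a local unsecured cluster at `localhost:9200` and an existing `test-index` with a `my_semantic_field` semantic text field; the docs snippet above targets `PUT test-index`, while the sketch below uses the conventional `_search` endpoint and adds an illustrative `query` string for the `semantic` query.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class SemanticHighlightRequest {
    public static void main(String[] args) throws Exception {
        // Mirrors the docs example above, issued against the Search API.
        String body = """
            {
              "query": {
                "semantic": {
                  "field": "my_semantic_field",
                  "query": "these are not the droids you're looking for"
                }
              },
              "highlight": {
                "fields": {
                  "my_semantic_field": { "type": "semantic", "number_of_fragments": 2 }
                }
              }
            }""";

        HttpRequest request = HttpRequest.newBuilder()
            .uri(URI.create("http://localhost:9200/test-index/_search"))
            .header("Content-Type", "application/json")
            .POST(HttpRequest.BodyPublishers.ofString(body))
            .build();

        HttpResponse<String> response = HttpClient.newHttpClient()
            .send(request, HttpResponse.BodyHandlers.ofString());
        // Each hit gains a highlight.my_semantic_field array holding at most
        // number_of_fragments chunk texts.
        System.out.println(response.body());
    }
}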

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 7 additions & 0 deletions
@@ -37,6 +37,7 @@
 import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
+import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
 import org.elasticsearch.search.rank.RankBuilder;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.threadpool.ExecutorBuilder;
@@ -67,6 +68,7 @@
 import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
+import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
@@ -411,4 +413,9 @@ public List<RetrieverSpec<?>> getRetrievers() {
             new RetrieverSpec<>(new ParseField(RandomRankBuilder.NAME), RandomRankRetrieverBuilder::fromXContent)
         );
     }
+
+    @Override
+    public Map<String, Highlighter> getHighlighters() {
+        return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter());
+    }
 }
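The registration above uses the standard `SearchPlugin#getHighlighters` extension point: the map key, `SemanticTextHighlighter.NAME` (`"semantic"`), is the value that the `type` option in a search request's `highlight` section resolves against, which is how the docs example earlier in this commit selects the new highlighter.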
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java

Lines changed: 208 additions & 0 deletions (new file)

/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.highlight;

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.
 * This highlighter extracts semantic queries and evaluates them against each chunk produced by the semantic text field.
 * It returns the top-scoring chunks as snippets, optionally sorted by their scores.
 */
public class SemanticTextHighlighter implements Highlighter {
    public static final String NAME = "semantic";

    private record OffsetAndScore(int offset, float score) {}

    @Override
    public boolean canHighlight(MappedFieldType fieldType) {
        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) {
            return true;
        }
        return false;
    }

    @Override
    public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
        SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;
        if (fieldType.getEmbeddingsField() == null) {
            // nothing indexed yet
            return null;
        }

        final List<Query> queries = switch (fieldType.getModelSettings().taskType()) {
            case SPARSE_EMBEDDING -> extractSparseVectorQueries(
                (SparseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),
                fieldContext.query
            );
            case TEXT_EMBEDDING -> extractDenseVectorQueries(
                (DenseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),
                fieldContext.query
            );
            default -> throw new IllegalStateException(
                "Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]"
            );
        };
        if (queries.isEmpty()) {
            // nothing to highlight
            return null;
        }

        int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0
            ? 1 // we return the best fragment by default
            : fieldContext.field.fieldOptions().numberOfFragments();

        List<OffsetAndScore> chunks = extractOffsetAndScores(
            fieldContext.context.getSearchExecutionContext(),
            fieldContext.hitContext.reader(),
            fieldType,
            fieldContext.hitContext.docId(),
            queries
        );
        if (chunks.size() == 0) {
            return null;
        }

        chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed());
        int size = Math.min(chunks.size(), numberOfFragments);
        if (fieldContext.field.fieldOptions().scoreOrdered() == false) {
            chunks = chunks.subList(0, size);
            chunks.sort(Comparator.comparingInt(c -> c.offset));
        }
        Text[] snippets = new Text[size];
        List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
            fieldType.getChunksField().fullPath(),
            fieldContext.hitContext.source().source()
        );
        for (int i = 0; i < size; i++) {
            var chunk = chunks.get(i);
            if (nestedSources.size() <= chunk.offset) {
                throw new IllegalStateException("Invalid content for field [" + fieldType.name() + "]");
            }
            String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);
            if (content == null) {
                throw new IllegalStateException("Invalid content for field [" + fieldType.name() + "]");
            }
            snippets[i] = new Text(content);
        }
        return new HighlightField(fieldContext.fieldName, snippets);
    }

    private List<OffsetAndScore> extractOffsetAndScores(
        SearchExecutionContext context,
        LeafReader reader,
        SemanticTextFieldMapper.SemanticTextFieldType fieldType,
        int docId,
        List<Query> leafQueries
    ) throws IOException {
        var bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext());
        int previousParent = docId > 0 ? bitSet.prevSetBit(docId - 1) : -1;

        BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER);
        leafQueries.stream().forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD));
        Weight weight = new IndexSearcher(reader).createWeight(bq.build(), ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.getContext());
        if (previousParent != -1) {
            if (scorer.iterator().advance(previousParent) == DocIdSetIterator.NO_MORE_DOCS) {
                return List.of();
            }
        } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
            return List.of();
        }
        List<OffsetAndScore> results = new ArrayList<>();
        int offset = 0;
        while (scorer.docID() < docId) {
            results.add(new OffsetAndScore(offset++, scorer.score()));
            if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
                break;
            }
        }
        return results;
    }

    private List<Query> extractDenseVectorQueries(DenseVectorFieldType fieldType, Query querySection) {
        // TODO: Handle knn section when semantic text field can be used.
        List<Query> queries = new ArrayList<>();
        querySection.visit(new QueryVisitor() {
            @Override
            public boolean acceptField(String field) {
                return fieldType.name().equals(field);
            }

            @Override
            public void consumeTerms(Query query, Term... terms) {
                super.consumeTerms(query, terms);
            }

            @Override
            public void visitLeaf(Query query) {
                if (query instanceof KnnFloatVectorQuery knnQuery) {
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(knnQuery.getTargetCopy()), null));
                } else if (query instanceof KnnByteVectorQuery knnQuery) {
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null));
                }
            }
        });
        return queries;
    }

    private List<Query> extractSparseVectorQueries(SparseVectorFieldType fieldType, Query querySection) {
        List<Query> queries = new ArrayList<>();
        querySection.visit(new QueryVisitor() {
            @Override
            public boolean acceptField(String field) {
                return fieldType.name().equals(field);
            }

            @Override
            public void consumeTerms(Query query, Term... terms) {
                super.consumeTerms(query, terms);
            }

            @Override
            public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
                if (parent instanceof SparseVectorQueryWrapper sparseVectorQuery) {
                    queries.add(sparseVectorQuery.getTermsQuery());
                }
                return this;
            }
        });
        return queries;
    }
}
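The offset bookkeeping in `extractOffsetAndScores` leans on Lucene's block-join layout: nested chunk documents are indexed contiguously just before their parent, so a hit's chunks occupy the doc ID range between the previous parent and the hit itself. A standalone sketch of that arithmetic using `java.util.BitSet` (Lucene's `BitSet.prevSetBit` is `previousSetBit` here; doc IDs are illustrative):

import java.util.BitSet;

public class BlockJoinOffsets {
    public static void main(String[] args) {
        // A set bit marks a parent (root) document. With block-join indexing,
        // docs 0-2 are the chunks of parent 3 and docs 4-5 the chunks of parent 6.
        BitSet parents = new BitSet();
        parents.set(3);
        parents.set(6);

        int docId = 6; // the hit being highlighted
        int previousParent = docId > 0 ? parents.previousSetBit(docId - 1) : -1;

        // The chunks of the hit are exactly the docs in (previousParent, docId);
        // visiting them in doc ID order yields the chunk offsets 0, 1, ...
        int offset = 0;
        for (int chunk = previousParent + 1; chunk < docId; chunk++) {
            System.out.printf("chunk doc %d -> offset %d%n", chunk, offset++);
        }
    }
}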

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ public record SemanticTextField(String fieldName, List<String> originalValues, I
     static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
     static final String CHUNKS_FIELD = "chunks";
     static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
-    static final String CHUNKED_TEXT_FIELD = "text";
+    public static final String CHUNKED_TEXT_FIELD = "text";
     static final String MODEL_SETTINGS_FIELD = "model_settings";
     static final String TASK_TYPE_FIELD = "task_type";
     static final String DIMENSIONS_FIELD = "dimensions";
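The constant is made `public` so the highlighter can read chunk text back out of `_source`. A minimal sketch of that lookup against a plain source map (illustrative values; the committed code resolves the nested sources via `XContentMapValues.extractNestedSources` first):

import java.util.List;
import java.util.Map;

public class ChunkTextLookup {
    // Mirrors SemanticTextField.CHUNKED_TEXT_FIELD, now public for the highlighter.
    static final String CHUNKED_TEXT_FIELD = "text";

    public static void main(String[] args) {
        // Shape of the nested chunk sources kept under the semantic field's chunks.
        List<Map<String, Object>> nestedSources = List.of(
            Map.of(CHUNKED_TEXT_FIELD, "these are not the droids"),
            Map.of(CHUNKED_TEXT_FIELD, "you're looking for")
        );
        int offset = 1; // offset of a top-scoring chunk
        String content = (String) nestedSources.get(offset).get(CHUNKED_TEXT_FIELD);
        System.out.println(content); // prints: you're looking for
    }
}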

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

Lines changed: 2 additions & 11 deletions
@@ -46,7 +46,6 @@
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.NestedQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.SimilarityMeasure;
@@ -57,6 +56,7 @@
 import org.elasticsearch.xcontent.XContentParserConfiguration;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
+import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -529,17 +529,8 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer req
                     );
                 }
 
-                // TODO: Use WeightedTokensQueryBuilder
                 TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults;
-                var boolQuery = QueryBuilders.boolQuery();
-                for (var weightedToken : textExpansionResults.getWeightedTokens()) {
-                    boolQuery.should(
-                        QueryBuilders.termQuery(inferenceResultsFieldName, weightedToken.token()).boost(weightedToken.weight())
-                    );
-                }
-                boolQuery.minimumShouldMatch(1);
-
-                yield boolQuery;
+                yield new SparseVectorQueryBuilder(name(), textExpansionResults.getWeightedTokens(), null, null, null, null);
             }
             case TEXT_EMBEDDING -> {
                 if (inferenceResults instanceof MlTextEmbeddingResults == false) {