/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.highlight;

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
 * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.
 * This highlighter extracts semantic queries and evaluates them against each chunk produced by the semantic text field.
 * It returns the top-scoring chunks as snippets, optionally sorted by their scores.
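 * <p>
 * The number of returned snippets is controlled by the {@code number_of_fragments} highlight option
 * (a single best chunk by default), and snippets are returned in score order when the {@code order}
 * option is set to {@code score}; otherwise they are returned in their original chunk order.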
 */
public class SemanticTextHighlighter implements Highlighter {
    public static final String NAME = "semantic";

    private record OffsetAndScore(int offset, float score) {}

    @Override
    public boolean canHighlight(MappedFieldType fieldType) {
        return fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType;
    }

    @Override
    public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
        SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;
        if (fieldType.getEmbeddingsField() == null) {
            // nothing indexed yet
            return null;
        }

        final List<Query> queries = switch (fieldType.getModelSettings().taskType()) {
            case SPARSE_EMBEDDING -> extractSparseVectorQueries(
                (SparseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),
                fieldContext.query
            );
            case TEXT_EMBEDDING -> extractDenseVectorQueries(
                (DenseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),
                fieldContext.query
            );
            default -> throw new IllegalStateException(
                "Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]"
            );
        };
        if (queries.isEmpty()) {
            // nothing to highlight
            return null;
        }

        int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0
            ? 1 // we return the best fragment by default
            : fieldContext.field.fieldOptions().numberOfFragments();

        List<OffsetAndScore> chunks = extractOffsetAndScores(
            fieldContext.context.getSearchExecutionContext(),
            fieldContext.hitContext.reader(),
            fieldType,
            fieldContext.hitContext.docId(),
            queries
        );
        if (chunks.isEmpty()) {
            return null;
        }

        chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed());
        int size = Math.min(chunks.size(), numberOfFragments);
        if (fieldContext.field.fieldOptions().scoreOrdered() == false) {
            chunks = chunks.subList(0, size);
            chunks.sort(Comparator.comparingInt(c -> c.offset));
        }
        Text[] snippets = new Text[size];
        List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
            fieldType.getChunksField().fullPath(),
            fieldContext.hitContext.source().source()
        );
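        // Each chunk is stored as a nested document under the chunks field; the offsets computed above
        // index into the extracted nested sources to recover the original chunk text.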
        for (int i = 0; i < size; i++) {
            var chunk = chunks.get(i);
            if (nestedSources.size() <= chunk.offset) {
                throw new IllegalStateException("Invalid content for field [" + fieldType.name() + "]");
            }
            String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);
            if (content == null) {
                throw new IllegalStateException("Invalid content for field [" + fieldType.name() + "]");
            }
            snippets[i] = new Text(content);
        }
        return new HighlightField(fieldContext.fieldName, snippets);
    }

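    /**
     * Runs the extracted queries against the nested chunks of the current root document and returns
     * the offset and score of each matching chunk.
     */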
    private List<OffsetAndScore> extractOffsetAndScores(
        SearchExecutionContext context,
        LeafReader reader,
        SemanticTextFieldMapper.SemanticTextFieldType fieldType,
        int docId,
        List<Query> leafQueries
    ) throws IOException {
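        // Nested chunk documents are stored immediately before their root document, so the chunks of
        // docId are the documents between the previous root (exclusive) and docId (exclusive).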
        var bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext());
        int previousParent = docId > 0 ? bitSet.prevSetBit(docId - 1) : -1;

        BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER);
        leafQueries.forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD));
        Weight weight = new IndexSearcher(reader).createWeight(bq.build(), ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.getContext());
        if (scorer == null) {
            // no chunk in this segment matches the extracted queries
            return List.of();
        }
        if (previousParent != -1) {
            if (scorer.iterator().advance(previousParent) == DocIdSetIterator.NO_MORE_DOCS) {
                return List.of();
            }
        } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
            return List.of();
        }
        List<OffsetAndScore> results = new ArrayList<>();
        int offset = 0;
        // collect the offset and score of every matching chunk that belongs to the current root document
        while (scorer.docID() < docId) {
            results.add(new OffsetAndScore(offset++, scorer.score()));
            if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
                break;
            }
        }
        return results;
    }

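    /**
     * Extracts the knn vector queries that target the embeddings sub-field and rewrites them as exact
     * knn queries so that every individual chunk can be scored.
     */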
    private List<Query> extractDenseVectorQueries(DenseVectorFieldType fieldType, Query querySection) {
        // TODO: Handle knn section when semantic text field can be used.
        List<Query> queries = new ArrayList<>();
        querySection.visit(new QueryVisitor() {
            @Override
            public boolean acceptField(String field) {
                return fieldType.name().equals(field);
            }

            @Override
            public void consumeTerms(Query query, Term... terms) {
                super.consumeTerms(query, terms);
            }

            @Override
            public void visitLeaf(Query query) {
                if (query instanceof KnnFloatVectorQuery knnQuery) {
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(knnQuery.getTargetCopy()), null));
                } else if (query instanceof KnnByteVectorQuery knnQuery) {
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null));
                }
            }
        });
        return queries;
    }

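    /**
     * Extracts the sparse vector term queries that target the embeddings sub-field by unwrapping
     * {@link SparseVectorQueryWrapper} instances found in the search query.
     */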
    private List<Query> extractSparseVectorQueries(SparseVectorFieldType fieldType, Query querySection) {
        List<Query> queries = new ArrayList<>();
        querySection.visit(new QueryVisitor() {
            @Override
            public boolean acceptField(String field) {
                return fieldType.name().equals(field);
            }

            @Override
            public void consumeTerms(Query query, Term... terms) {
                super.consumeTerms(query, terms);
            }

            @Override
            public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
                if (parent instanceof SparseVectorQueryWrapper sparseVectorQuery) {
                    queries.add(sparseVectorQuery.getTermsQuery());
                }
                return this;
            }
        });
        return queries;
    }
}