| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the Elastic License  | 
 | 4 | + * 2.0; you may not use this file except in compliance with the Elastic License  | 
 | 5 | + * 2.0.  | 
 | 6 | + */  | 
 | 7 | + | 
 | 8 | +package org.elasticsearch.xpack.inference.highlight;  | 
 | 9 | + | 
 | 10 | +import org.apache.lucene.index.LeafReader;  | 
 | 11 | +import org.apache.lucene.index.Term;  | 
 | 12 | +import org.apache.lucene.search.BooleanClause;  | 
 | 13 | +import org.apache.lucene.search.BooleanQuery;  | 
 | 14 | +import org.apache.lucene.search.DocIdSetIterator;  | 
 | 15 | +import org.apache.lucene.search.IndexSearcher;  | 
 | 16 | +import org.apache.lucene.search.KnnByteVectorQuery;  | 
 | 17 | +import org.apache.lucene.search.KnnFloatVectorQuery;  | 
 | 18 | +import org.apache.lucene.search.Query;  | 
 | 19 | +import org.apache.lucene.search.QueryVisitor;  | 
 | 20 | +import org.apache.lucene.search.ScoreMode;  | 
 | 21 | +import org.apache.lucene.search.Scorer;  | 
 | 22 | +import org.apache.lucene.search.Weight;  | 
 | 23 | +import org.elasticsearch.common.text.Text;  | 
 | 24 | +import org.elasticsearch.common.xcontent.support.XContentMapValues;  | 
 | 25 | +import org.elasticsearch.index.mapper.MappedFieldType;  | 
 | 26 | +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;  | 
 | 27 | +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;  | 
 | 28 | +import org.elasticsearch.index.query.SearchExecutionContext;  | 
 | 29 | +import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;  | 
 | 30 | +import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;  | 
 | 31 | +import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;  | 
 | 32 | +import org.elasticsearch.search.vectors.VectorData;  | 
 | 33 | +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;  | 
 | 34 | +import org.elasticsearch.xpack.inference.mapper.SemanticTextField;  | 
 | 35 | +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;  | 
 | 36 | + | 
 | 37 | +import java.io.IOException;  | 
 | 38 | +import java.util.ArrayList;  | 
 | 39 | +import java.util.Comparator;  | 
 | 40 | +import java.util.List;  | 
 | 41 | +import java.util.Locale;  | 
 | 42 | +import java.util.Map;  | 
 | 43 | + | 
 | 44 | +/**  | 
 | 45 | + * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.  | 
 | 46 | + * This highlighter extracts semantic queries and evaluates them against each chunk produced by the semantic text field.  | 
 | 47 | + * It returns the top-scoring chunks as snippets, optionally sorted by their scores.  | 
 | 48 | + */  | 
 | 49 | +public class SemanticTextHighlighter implements Highlighter {  | 
 | 50 | +    public static final String NAME = "semantic";  | 
 | 51 | + | 
 | 52 | +    private record OffsetAndScore(int offset, float score) {}  | 
 | 53 | + | 
 | 54 | +    @Override  | 
 | 55 | +    public boolean canHighlight(MappedFieldType fieldType) {  | 
 | 56 | +        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) {  | 
 | 57 | +            return true;  | 
 | 58 | +        }  | 
 | 59 | +        return false;  | 
 | 60 | +    }  | 
 | 61 | + | 
 | 62 | +    @Override  | 
 | 63 | +    public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {  | 
 | 64 | +        SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;  | 
 | 65 | +        if (fieldType.getEmbeddingsField() == null) {  | 
 | 66 | +            // nothing indexed yet  | 
 | 67 | +            return null;  | 
 | 68 | +        }  | 
 | 69 | + | 
 | 70 | +        final List<Query> queries = switch (fieldType.getModelSettings().taskType()) {  | 
 | 71 | +            case SPARSE_EMBEDDING -> extractSparseVectorQueries(  | 
 | 72 | +                (SparseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),  | 
 | 73 | +                fieldContext.query  | 
 | 74 | +            );  | 
 | 75 | +            case TEXT_EMBEDDING -> extractDenseVectorQueries(  | 
 | 76 | +                (DenseVectorFieldType) fieldType.getEmbeddingsField().fieldType(),  | 
 | 77 | +                fieldContext.query  | 
 | 78 | +            );  | 
 | 79 | +            default -> throw new IllegalStateException(  | 
 | 80 | +                "Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]"  | 
 | 81 | +            );  | 
 | 82 | +        };  | 
 | 83 | +        if (queries.isEmpty()) {  | 
 | 84 | +            // nothing to highlight  | 
 | 85 | +            return null;  | 
 | 86 | +        }  | 
 | 87 | + | 
 | 88 | +        int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0  | 
 | 89 | +            ? 1 // we return the best fragment by default  | 
 | 90 | +            : fieldContext.field.fieldOptions().numberOfFragments();  | 
 | 91 | + | 
 | 92 | +        List<OffsetAndScore> chunks = extractOffsetAndScores(  | 
 | 93 | +            fieldContext.context.getSearchExecutionContext(),  | 
 | 94 | +            fieldContext.hitContext.reader(),  | 
 | 95 | +            fieldType,  | 
 | 96 | +            fieldContext.hitContext.docId(),  | 
 | 97 | +            queries  | 
 | 98 | +        );  | 
 | 99 | +        if (chunks.size() == 0) {  | 
 | 100 | +            return null;  | 
 | 101 | +        }  | 
 | 102 | + | 
 | 103 | +        chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed());  | 
 | 104 | +        int size = Math.min(chunks.size(), numberOfFragments);  | 
 | 105 | +        if (fieldContext.field.fieldOptions().scoreOrdered() == false) {  | 
 | 106 | +            chunks = chunks.subList(0, size);  | 
 | 107 | +            chunks.sort(Comparator.comparingInt(c -> c.offset));  | 
 | 108 | +        }  | 
 | 109 | +        Text[] snippets = new Text[size];  | 
 | 110 | +        List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(  | 
 | 111 | +            fieldType.getChunksField().fullPath(),  | 
 | 112 | +            fieldContext.hitContext.source().source()  | 
 | 113 | +        );  | 
 | 114 | +        for (int i = 0; i < size; i++) {  | 
 | 115 | +            var chunk = chunks.get(i);  | 
 | 116 | +            if (nestedSources.size() <= chunk.offset) {  | 
 | 117 | +                throw new IllegalStateException(  | 
 | 118 | +                    String.format(  | 
 | 119 | +                        Locale.ROOT,  | 
 | 120 | +                        "Invalid content detected for field [%s]: the chunks size is [%d], "  | 
 | 121 | +                            + "but a reference to offset [%d] was found in the result.",  | 
 | 122 | +                        fieldType.name(),  | 
 | 123 | +                        nestedSources.size(),  | 
 | 124 | +                        chunk.offset  | 
 | 125 | +                    )  | 
 | 126 | +                );  | 
 | 127 | +            }  | 
 | 128 | +            String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);  | 
 | 129 | +            if (content == null) {  | 
 | 130 | +                throw new IllegalStateException(  | 
 | 131 | +                    String.format(  | 
 | 132 | +                        Locale.ROOT,  | 
 | 133 | + | 
 | 134 | +                        "Invalid content detected for field [%s]: missing text for the chunk at offset [%d].",  | 
 | 135 | +                        fieldType.name(),  | 
 | 136 | +                        chunk.offset  | 
 | 137 | +                    )  | 
 | 138 | +                );  | 
 | 139 | +            }  | 
 | 140 | +            snippets[i] = new Text(content);  | 
 | 141 | +        }  | 
 | 142 | +        return new HighlightField(fieldContext.fieldName, snippets);  | 
 | 143 | +    }  | 
 | 144 | + | 
 | 145 | +    private List<OffsetAndScore> extractOffsetAndScores(  | 
 | 146 | +        SearchExecutionContext context,  | 
 | 147 | +        LeafReader reader,  | 
 | 148 | +        SemanticTextFieldMapper.SemanticTextFieldType fieldType,  | 
 | 149 | +        int docId,  | 
 | 150 | +        List<Query> leafQueries  | 
 | 151 | +    ) throws IOException {  | 
 | 152 | +        var bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext());  | 
 | 153 | +        int previousParent = docId > 0 ? bitSet.prevSetBit(docId - 1) : -1;  | 
 | 154 | + | 
 | 155 | +        BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER);  | 
 | 156 | +        leafQueries.stream().forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD));  | 
 | 157 | +        Weight weight = new IndexSearcher(reader).createWeight(bq.build(), ScoreMode.COMPLETE, 1);  | 
 | 158 | +        Scorer scorer = weight.scorer(reader.getContext());  | 
 | 159 | +        if (previousParent != -1) {  | 
 | 160 | +            if (scorer.iterator().advance(previousParent) == DocIdSetIterator.NO_MORE_DOCS) {  | 
 | 161 | +                return List.of();  | 
 | 162 | +            }  | 
 | 163 | +        } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {  | 
 | 164 | +            return List.of();  | 
 | 165 | +        }  | 
 | 166 | +        List<OffsetAndScore> results = new ArrayList<>();  | 
 | 167 | +        int offset = 0;  | 
 | 168 | +        while (scorer.docID() < docId) {  | 
 | 169 | +            results.add(new OffsetAndScore(offset++, scorer.score()));  | 
 | 170 | +            if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {  | 
 | 171 | +                break;  | 
 | 172 | +            }  | 
 | 173 | +        }  | 
 | 174 | +        return results;  | 
 | 175 | +    }  | 
 | 176 | + | 
 | 177 | +    private List<Query> extractDenseVectorQueries(DenseVectorFieldType fieldType, Query querySection) {  | 
 | 178 | +        // TODO: Handle knn section when semantic text field can be used.  | 
 | 179 | +        List<Query> queries = new ArrayList<>();  | 
 | 180 | +        querySection.visit(new QueryVisitor() {  | 
 | 181 | +            @Override  | 
 | 182 | +            public boolean acceptField(String field) {  | 
 | 183 | +                return fieldType.name().equals(field);  | 
 | 184 | +            }  | 
 | 185 | + | 
 | 186 | +            @Override  | 
 | 187 | +            public void consumeTerms(Query query, Term... terms) {  | 
 | 188 | +                super.consumeTerms(query, terms);  | 
 | 189 | +            }  | 
 | 190 | + | 
 | 191 | +            @Override  | 
 | 192 | +            public void visitLeaf(Query query) {  | 
 | 193 | +                if (query instanceof KnnFloatVectorQuery knnQuery) {  | 
 | 194 | +                    queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(knnQuery.getTargetCopy()), null));  | 
 | 195 | +                } else if (query instanceof KnnByteVectorQuery knnQuery) {  | 
 | 196 | +                    queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null));  | 
 | 197 | +                }  | 
 | 198 | +            }  | 
 | 199 | +        });  | 
 | 200 | +        return queries;  | 
 | 201 | +    }  | 
 | 202 | + | 
 | 203 | +    private List<Query> extractSparseVectorQueries(SparseVectorFieldType fieldType, Query querySection) {  | 
 | 204 | +        List<Query> queries = new ArrayList<>();  | 
 | 205 | +        querySection.visit(new QueryVisitor() {  | 
 | 206 | +            @Override  | 
 | 207 | +            public boolean acceptField(String field) {  | 
 | 208 | +                return fieldType.name().equals(field);  | 
 | 209 | +            }  | 
 | 210 | + | 
 | 211 | +            @Override  | 
 | 212 | +            public void consumeTerms(Query query, Term... terms) {  | 
 | 213 | +                super.consumeTerms(query, terms);  | 
 | 214 | +            }  | 
 | 215 | + | 
 | 216 | +            @Override  | 
 | 217 | +            public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {  | 
 | 218 | +                if (parent instanceof SparseVectorQueryWrapper sparseVectorQuery) {  | 
 | 219 | +                    queries.add(sparseVectorQuery.getTermsQuery());  | 
 | 220 | +                }  | 
 | 221 | +                return this;  | 
 | 222 | +            }  | 
 | 223 | +        });  | 
 | 224 | +        return queries;  | 
 | 225 | +    }  | 
 | 226 | +}  | 
0 commit comments