DefaultHighlighter.java
@@ -227,15 +227,15 @@ protected static BreakIterator getBreakIterator(SearchHighlightContext.Field field) {
}
}

- protected static String convertFieldValue(MappedFieldType type, Object value) {
+ public static String convertFieldValue(MappedFieldType type, Object value) {
if (value instanceof BytesRef) {
return type.valueForDisplay(value).toString();
} else {
return value.toString();
}
}

- protected static String mergeFieldValues(List<Object> fieldValues, char valuesSeparator) {
+ public static String mergeFieldValues(List<Object> fieldValues, char valuesSeparator) {
// postings highlighter accepts all values in a single string, as offsets etc. need to match with content
// loaded from stored fields, we merge all values using a proper separator
String rawValue = Strings.collectionToDelimitedString(fieldValues, String.valueOf(valuesSeparator));
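These two helpers become public so that highlighters outside this package can reuse them. A minimal sketch of the call pattern, mirroring the extractFieldContent helper added to SemanticTextHighlighter below (the wrapper class and method names here are hypothetical):

    import org.elasticsearch.index.mapper.MappedFieldType;
    import org.elasticsearch.search.fetch.subphase.highlight.DefaultHighlighter;

    import java.util.List;

    import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;

    class MergedContentSketch {
        // Convert each stored value for display, then join the values with the
        // unified highlighter's separator so offsets computed against the merged
        // string line up with the content loaded from stored fields.
        static String mergeForHighlighting(MappedFieldType fieldType, List<Object> rawValues) {
            List<Object> converted = rawValues.stream()
                .<Object>map(v -> DefaultHighlighter.convertFieldValue(fieldType, v))
                .toList();
            return DefaultHighlighter.mergeFieldValues(converted, MULTIVAL_SEP_CHAR);
        }
    }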
SemanticTextHighlighter.java
@@ -26,20 +26,31 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType;
import org.elasticsearch.index.query.SearchExecutionContext;
+ import org.elasticsearch.search.fetch.FetchSubPhase;
+ import org.elasticsearch.search.fetch.subphase.highlight.DefaultHighlighter;
import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
+ import org.elasticsearch.search.fetch.subphase.highlight.HighlightUtils;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
+ import org.elasticsearch.xpack.inference.mapper.OffsetSourceField;
+ import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
+ import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SemanticTextFieldType;

import java.io.IOException;
+ import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Comparator;
+ import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
+ import java.util.function.Function;
+
+ import static org.elasticsearch.lucene.search.uhighlight.CustomUnifiedHighlighter.MULTIVAL_SEP_CHAR;

/**
* A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}.
@@ -49,20 +60,19 @@
public class SemanticTextHighlighter implements Highlighter {
public static final String NAME = "semantic";

- private record OffsetAndScore(int offset, float score) {}
+ private record OffsetAndScore(int index, OffsetSourceFieldMapper.OffsetSource offset, float score) {}

@Override
public boolean canHighlight(MappedFieldType fieldType) {
- if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
- // TODO: Implement highlighting when using inference metadata fields
- return semanticTextFieldType.useLegacyFormat();
- }
- return false;
+ return fieldType instanceof SemanticTextFieldType;
}

@Override
public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
- SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType;
+ if (canHighlight(fieldContext.fieldType) == false) {
+ return null;
+ }
+ SemanticTextFieldType fieldType = (SemanticTextFieldType) fieldContext.fieldType;
if (fieldType.getEmbeddingsField() == null) {
// nothing indexed yet
return null;
@@ -105,28 +115,36 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
int size = Math.min(chunks.size(), numberOfFragments);
if (fieldContext.field.fieldOptions().scoreOrdered() == false) {
chunks = chunks.subList(0, size);
- chunks.sort(Comparator.comparingInt(c -> c.offset));
+ chunks.sort(Comparator.comparingInt(c -> c.index));
}
Text[] snippets = new Text[size];
- List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
- fieldType.getChunksField().fullPath(),
- fieldContext.hitContext.source().source()
- );
+ final Function<OffsetAndScore, String> offsetToContent;
+ if (fieldType.useLegacyFormat()) {
+ List<Map<?, ?>> nestedSources = XContentMapValues.extractNestedSources(
+ fieldType.getChunksField().fullPath(),
+ fieldContext.hitContext.source().source()
+ );
+ offsetToContent = entry -> getContentFromNestedSourcesLegacy(fieldType.name(), entry, nestedSources);
+ } else {
+ Map<String, String> fieldToContent = new HashMap<>();
+ offsetToContent = entry -> {
+ String content = fieldToContent.computeIfAbsent(entry.offset().field(), key -> {
+ try {
+ return extractFieldContent(
+ fieldContext.context.getSearchExecutionContext(),
+ fieldContext.hitContext,
+ entry.offset.field()
+ );
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ });
+ return content.substring(entry.offset().start(), entry.offset().end());
+ };
+ }
for (int i = 0; i < size; i++) {
var chunk = chunks.get(i);
- if (nestedSources.size() <= chunk.offset) {
- throw new IllegalStateException(
- String.format(
- Locale.ROOT,
- "Invalid content detected for field [%s]: the chunks size is [%d], "
- + "but a reference to offset [%d] was found in the result.",
- fieldType.name(),
- nestedSources.size(),
- chunk.offset
- )
- );
- }
- String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD);
+ String content = offsetToContent.apply(chunk);
if (content == null) {
throw new IllegalStateException(
String.format(
@@ -143,10 +161,43 @@ public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
return new HighlightField(fieldContext.fieldName, snippets);
}

+ private String extractFieldContent(SearchExecutionContext searchContext, FetchSubPhase.HitContext hitContext, String sourceField)
+ throws IOException {
+ var sourceFieldType = searchContext.getMappingLookup().getFieldType(sourceField);
+ if (sourceFieldType == null) {
+ return null;
+ }
+
+ var values = HighlightUtils.loadFieldValues(sourceFieldType, searchContext, hitContext)
+ .stream()
+ .<Object>map((s) -> DefaultHighlighter.convertFieldValue(sourceFieldType, s))
+ .toList();
+ if (values.size() == 0) {
+ return null;
+ }
+ return DefaultHighlighter.mergeFieldValues(values, MULTIVAL_SEP_CHAR);
+ }
+
+ private String getContentFromNestedSourcesLegacy(String fieldName, OffsetAndScore cand, List<Map<?, ?>> nestedSources) {
+ if (nestedSources.size() <= cand.index) {
+ throw new IllegalStateException(
+ String.format(
+ Locale.ROOT,
+ "Invalid content detected for field [%s]: the chunks size is [%d], "
+ + "but a reference to offset [%d] was found in the result.",
+ fieldName,
+ nestedSources.size(),
+ cand.index
+ )
+ );
+ }
+ return (String) nestedSources.get(cand.index).get(SemanticTextField.CHUNKED_TEXT_FIELD);
+ }

private List<OffsetAndScore> extractOffsetAndScores(
SearchExecutionContext context,
LeafReader reader,
- SemanticTextFieldMapper.SemanticTextFieldType fieldType,
+ SemanticTextFieldType fieldType,
int docId,
List<Query> leafQueries
) throws IOException {
@@ -164,10 +215,31 @@ private List<OffsetAndScore> extractOffsetAndScores(
} else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
return List.of();
}
+
+ OffsetSourceField.OffsetSourceLoader offsetReader = null;
+ if (fieldType.useLegacyFormat() == false) {
+ var terms = reader.terms(fieldType.getOffsetsField().fullPath());
+ if (terms == null) {
+ // The field is empty
+ return List.of();
+ }
+ offsetReader = OffsetSourceField.loader(terms);
+ }
+
List<OffsetAndScore> results = new ArrayList<>();
- int offset = 0;
+ int index = 0;
while (scorer.docID() < docId) {
- results.add(new OffsetAndScore(offset++, scorer.score()));
+ if (offsetReader != null) {
+ var offset = offsetReader.advanceTo(scorer.docID());
+ if (offset == null) {
+ throw new IllegalStateException(
+ "Cannot highlight field [" + fieldType.name() + "], missing embeddings for doc [" + docId + "]"
+ );
+ }
+ results.add(new OffsetAndScore(index++, offset, scorer.score()));
+ } else {
+ results.add(new OffsetAndScore(index++, null, scorer.score()));
+ }
if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
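With the new (non-legacy) format, chunk text is no longer duplicated in the nested chunks source; each scored chunk instead carries an offset-source entry (source field name plus start/end character offsets), and the snippet is cut directly out of the merged field content. A toy illustration of that final step, with made-up values:

    // Hypothetical values: the field content is loaded once per source field
    // (see extractFieldContent above); offsets come from the offsets field.
    String fieldContent = "A quick brown fox jumps over the lazy dog. Second sentence.";
    int start = 0;
    int end = 42;
    String snippet = fieldContent.substring(start, end);
    // -> "A quick brown fox jumps over the lazy dog."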
SemanticTextHighlighterTests.java
@@ -7,6 +7,8 @@

package org.elasticsearch.xpack.inference.highlight;

+ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriterConfig;
@@ -51,7 +53,6 @@
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
- import org.junit.Before;
import org.mockito.Mockito;

import java.io.IOException;
@@ -71,31 +72,35 @@ public class SemanticTextHighlighterTests extends MapperServiceTestCase {
private static final String SEMANTIC_FIELD_E5 = "body-e5";
private static final String SEMANTIC_FIELD_ELSER = "body-elser";

- private Map<String, Object> queries;
+ private final boolean useLegacyFormat;
+ private final Map<String, Object> queries;

- @Override
- protected Collection<? extends Plugin> getPlugins() {
- return List.of(new InferencePlugin(Settings.EMPTY));
+ public SemanticTextHighlighterTests(boolean useLegacyFormat) throws IOException {
+ this.useLegacyFormat = useLegacyFormat;
+ var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json"));
+ this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2();
}

+ @ParametersFactory
+ public static Iterable<Object[]> parameters() throws Exception {
+ return List.of(new Object[] { true }, new Object[] { false });
+ }
+
@Override
- @Before
- public void setUp() throws Exception {
- super.setUp();
- var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json"));
- this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2();
+ protected Collection<? extends Plugin> getPlugins() {
+ return List.of(new InferencePlugin(Settings.EMPTY));
}

@SuppressWarnings("unchecked")
public void testDenseVector() throws Exception {
- var mapperService = createDefaultMapperService();
+ var mapperService = createDefaultMapperService(useLegacyFormat);
Map<String, Object> queryMap = (Map<String, Object>) queries.get("dense_vector_1");
float[] vector = readDenseVector(queryMap.get("embeddings"));
var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5);
KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null, null);
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max);
var shardRequest = createShardSearchRequest(nestedQueryBuilder);
- var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON);
+ var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON);

String[] expectedScorePassages = ((List<String>) queryMap.get("expected_by_score")).toArray(String[]::new);
for (int i = 0; i < expectedScorePassages.length; i++) {
@@ -124,7 +129,7 @@ public void testDenseVector() throws Exception {

@SuppressWarnings("unchecked")
public void testSparseVector() throws Exception {
- var mapperService = createDefaultMapperService();
+ var mapperService = createDefaultMapperService(useLegacyFormat);
Map<String, Object> queryMap = (Map<String, Object>) queries.get("sparse_vector_1");
List<WeightedToken> tokens = readSparseVector(queryMap.get("embeddings"));
var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_ELSER);
@@ -138,7 +143,7 @@
);
NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), sparseQuery, ScoreMode.Max);
var shardRequest = createShardSearchRequest(nestedQueryBuilder);
- var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON);
+ var sourceToParse = new SourceToParse("0", readSampleDoc(useLegacyFormat), XContentType.JSON);

String[] expectedScorePassages = ((List<String>) queryMap.get("expected_by_score")).toArray(String[]::new);
for (int i = 0; i < expectedScorePassages.length; i++) {
@@ -165,9 +170,11 @@
);
}

- private MapperService createDefaultMapperService() throws IOException {
+ private MapperService createDefaultMapperService(boolean useLegacyFormat) throws IOException {
var mappings = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("mappings.json"));
- var settings = Settings.builder().put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), true).build();
+ var settings = Settings.builder()
+ .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
+ .build();
return createMapperService(settings, mappings.utf8ToString());
}

@@ -282,7 +289,8 @@ private ShardSearchRequest createShardSearchRequest(QueryBuilder queryBuilder) {
return new ShardSearchRequest(OriginalIndices.NONE, request, new ShardId("index", "index", 0), 0, 1, AliasFilter.EMPTY, 1, 0, null);
}

- private BytesReference readSampleDoc(String fileName) throws IOException {
+ private BytesReference readSampleDoc(boolean useLegacyFormat) throws IOException {
+ String fileName = useLegacyFormat ? "sample-doc-legacy.json.gz" : "sample-doc.json.gz";
try (var in = new GZIPInputStream(SemanticTextHighlighterTests.class.getResourceAsStream(fileName))) {
return new BytesArray(new BytesRef(in.readAllBytes()));
}
Binary file not shown.
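The test class now runs every case twice, once per source format, via randomizedtesting's parameters factory: each Object[] the factory returns is passed to the constructor, which replaces the former @Before setup. A standalone sketch of the pattern (hypothetical class, not part of this PR):

    import com.carrotsearch.randomizedtesting.RandomizedRunner;
    import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

    import org.junit.Test;
    import org.junit.runner.RunWith;

    import java.util.List;

    @RunWith(RandomizedRunner.class)
    public class FormatParamSketch {
        private final boolean useLegacyFormat;

        // The runner invokes the constructor once per Object[] from the factory.
        public FormatParamSketch(boolean useLegacyFormat) {
            this.useLegacyFormat = useLegacyFormat;
        }

        @ParametersFactory
        public static Iterable<Object[]> parameters() {
            return List.of(new Object[] { true }, new Object[] { false });
        }

        @Test
        public void testRunsPerFormat() {
            // Executes twice: once with the legacy format, once with the new one.
            System.out.println("legacy format: " + useLegacyFormat);
        }
    }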