diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index f3aee46398432..cd68dbcc07d2f 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -208,6 +208,7 @@ import org.elasticsearch.search.aggregations.pipeline.StatsBucketPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; +import org.elasticsearch.search.diversification.ResultDiversificationRetrieverBuilder; import org.elasticsearch.search.fetch.FetchPhase; import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.subphase.ExplainPhase; @@ -1087,6 +1088,9 @@ private void registerRetrieverParsers(List plugins) { registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent)); + registerRetriever( + new RetrieverSpec<>(ResultDiversificationRetrieverBuilder.NAME, ResultDiversificationRetrieverBuilder::fromXContent) + ); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } diff --git a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java new file mode 100644 index 0000000000000..fca8f385780c6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversification.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.VectorData; + +import java.io.IOException; + +/** + * Base interface for result diversification. + */ +public abstract class ResultDiversification { + + public abstract RankDoc[] diversify(RankDoc[] docs, ResultDiversificationContext diversificationContext) throws IOException; + + protected float getVectorComparisonScore( + VectorSimilarityFunction similarityFunction, + boolean useFloat, + VectorData thisDocVector, + VectorData comparisonVector + ) { + return useFloat + ? similarityFunction.compare(thisDocVector.floatVector(), comparisonVector.floatVector()) + : similarityFunction.compare(thisDocVector.byteVector(), comparisonVector.byteVector()); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java new file mode 100644 index 0000000000000..40eb2dbf0deb3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationContext.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.search.vectors.VectorData; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +public abstract class ResultDiversificationContext { + private final String field; + private final int numCandidates; + private final DenseVectorFieldMapper fieldMapper; + private final IndexVersion indexVersion; + private final VectorData queryVector; + private Map fieldVectors; + + // Field _must_ be a dense_vector type + protected ResultDiversificationContext( + String field, + int numCandidates, + DenseVectorFieldMapper fieldMapper, + IndexVersion indexVersion, + @Nullable VectorData queryVector, + @Nullable Map fieldVectors + ) { + this.field = field; + this.numCandidates = numCandidates; + this.fieldMapper = fieldMapper; + this.indexVersion = indexVersion; + this.queryVector = queryVector; + this.fieldVectors = fieldVectors == null ? new HashMap<>() : fieldVectors; + } + + public String getField() { + return field; + } + + public int getNumCandidates() { + return numCandidates; + } + + public DenseVectorFieldMapper getFieldMapper() { + return fieldMapper; + } + + public DenseVectorFieldMapper.ElementType getElementType() { + return fieldMapper.fieldType().getElementType(); + } + + public IndexVersion getIndexVersion() { + return indexVersion; + } + + public void setFieldVectors(Map fieldVectors) { + this.fieldVectors = fieldVectors; + } + + public VectorData getQueryVector() { + return queryVector; + } + + public VectorData getFieldVector(int docId) { + return fieldVectors.getOrDefault(docId, null); + } + + public Set> getFieldVectorsEntrySet() { + return fieldVectors.entrySet(); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilder.java new file mode 100644 index 0000000000000..e65191fd18747 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilder.java @@ -0,0 +1,313 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.Mapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.diversification.mmr.MMRResultDiversification; +import org.elasticsearch.search.diversification.mmr.MMRResultDiversificationContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public final class ResultDiversificationRetrieverBuilder extends CompoundRetrieverBuilder { + + public static final String DIVERSIFICATION_TYPE_MMR = "mmr"; + public static final Float DEFAULT_LAMBDA_VALUE = 0.7f; + + public static final String NAME = "diversify"; + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + public static final ParseField TYPE_FIELD = new ParseField("type"); + public static final ParseField FIELD_FIELD = new ParseField("field"); + public static final ParseField QUERY_FIELD = new ParseField("query_vector"); + public static final ParseField LAMBDA_FIELD = new ParseField("lambda"); + + public static class RankDocWithSearchHit extends RankDoc { + private final SearchHit hit; + + public RankDocWithSearchHit(int doc, float score, int shardIndex, SearchHit hit) { + super(doc, score, shardIndex); + this.hit = hit; + } + + public SearchHit hit() { + return hit; + } + } + + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, false, args -> { + String diversificationType = (String) args[1]; + String diversificationField = (String) args[2]; + int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3]; + + @SuppressWarnings("unchecked") + ArrayList queryVectorList = args[4] == null ? null : (ArrayList) args[4]; + float[] queryVector = null; + if (queryVectorList != null) { + queryVector = new float[queryVectorList.size()]; + for (int i = 0; i < queryVectorList.size(); i++) { + queryVector[i] = queryVectorList.get(i); + } + } + + Float lambda = (Float) args[5]; + return new ResultDiversificationRetrieverBuilder( + RetrieverSource.from((RetrieverBuilder) args[0]), + diversificationType, + diversificationField, + rankWindowSize, + queryVector, + lambda + ); + }); + + static { + PARSER.declareNamedObject(constructorArg(), (parser, context, n) -> { + RetrieverBuilder innerRetriever = parser.namedObject(RetrieverBuilder.class, n, context); + context.trackRetrieverUsage(innerRetriever); + return innerRetriever; + }, RETRIEVER_FIELD); + PARSER.declareString(constructorArg(), TYPE_FIELD); + PARSER.declareString(constructorArg(), FIELD_FIELD); + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareFloatArray(optionalConstructorArg(), QUERY_FIELD); + PARSER.declareFloat(optionalConstructorArg(), LAMBDA_FIELD); + RetrieverBuilder.declareBaseParserFields(PARSER); + } + + private final String diversificationType; + private final String diversificationField; + private final float[] queryVector; + private final Float lambda; + private ResultDiversificationContext diversificationContext = null; + + ResultDiversificationRetrieverBuilder( + RetrieverSource innerRetriever, + String diversificationType, + String diversificationField, + int rankWindowSize, + @Nullable float[] queryVector, + @Nullable Float lambda + ) { + super(List.of(innerRetriever), rankWindowSize); + this.diversificationType = diversificationType; + this.diversificationField = diversificationField; + this.queryVector = queryVector; + this.lambda = lambda; + } + + @Override + protected ResultDiversificationRetrieverBuilder clone( + List newChildRetrievers, + List newPreFilterQueryBuilders + ) { + assert newChildRetrievers.size() == 1 : "ResultDiversificationRetrieverBuilder must have a single child retriever"; + return new ResultDiversificationRetrieverBuilder( + newChildRetrievers.getFirst(), + diversificationType, + diversificationField, + rankWindowSize, + queryVector, + lambda + ); + } + + @Override + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean isScroll, + boolean allowPartialSearchResults + ) { + // ensure the type is one we know of - at the moment, only "mmr" is valid + if (diversificationType.equals(DIVERSIFICATION_TYPE_MMR) == false) { + validationException = addValidationError( + String.format(Locale.ROOT, "[%s] diversification type must be set to [%s]", getName(), DIVERSIFICATION_TYPE_MMR), + validationException + ); + } + + // if MMR, ensure we have a lambda between 0.0 and 1.0 + if (diversificationType.equals(DIVERSIFICATION_TYPE_MMR) && (lambda == null || lambda < 0.0 || lambda > 1.0)) { + validationException = addValidationError( + String.format( + Locale.ROOT, + "[%s] MMR result diversification must have a [%s] between 0.0 and 1.0", + getName(), + LAMBDA_FIELD.getPreferredName() + ), + validationException + ); + } + + return validationException; + } + + @Override + protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { + IndexVersion indexVersion = ctx.getIndexSettings().getIndexVersionCreated(); + Mapper mapper = ctx.getMappingLookup().getMapper(diversificationField); + if (mapper instanceof DenseVectorFieldMapper == false) { + throw new IllegalArgumentException("[" + diversificationField + "] is not a dense vector field"); + } + + if (diversificationType.equals(DIVERSIFICATION_TYPE_MMR)) { + diversificationContext = new MMRResultDiversificationContext( + diversificationField, + lambda == null ? DEFAULT_LAMBDA_VALUE : lambda, + rankWindowSize, + (DenseVectorFieldMapper) mapper, + indexVersion, + queryVector == null ? null : new VectorData(queryVector), + null + ); + } else { + // should not happen + throw new IllegalArgumentException("Unknown diversification type [" + diversificationType + "]"); + } + + return this; + } + + @Override + protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { + return super.finalizeSourceBuilder(sourceBuilder).docValueField(diversificationField); + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean explain) { + // must have the combined result set + assert rankResults.size() == 1 : "ResultDiversificationRetrieverBuilder must have a single result set"; + assert diversificationContext != null : "diversificationContext should be set before combining results"; + + ScoreDoc[] scoreDocs = rankResults.getFirst(); + if (scoreDocs == null || scoreDocs.length == 0) { + // might happen in the case where we have no results + return new RankDoc[0]; + } + + assert scoreDocs[0] instanceof RankDocWithSearchHit : "expected results to be of type RankDocWithSearchHit"; + + // gather and set the query vectors + // and create our intermediate results set + RankDoc[] results = new RankDoc[scoreDocs.length]; + Map fieldVectors = new HashMap<>(); + for (int i = 0; i < scoreDocs.length; i++) { + RankDocWithSearchHit asRankDoc = (RankDocWithSearchHit) scoreDocs[i]; + results[i] = asRankDoc; + + var field = asRankDoc.hit().getFields().getOrDefault(diversificationField, null); + if (field != null) { + var fieldValue = field.getValue(); + if (fieldValue instanceof float[]) { + fieldVectors.put(asRankDoc.doc, new VectorData((float[]) field.getValue())); + } else if (fieldValue instanceof byte[]) { + fieldVectors.put(asRankDoc.doc, new VectorData((byte[]) field.getValue())); + } + } + } + diversificationContext.setFieldVectors(fieldVectors); + + try { + if (diversificationType.equals(DIVERSIFICATION_TYPE_MMR)) { + MMRResultDiversification diversification = new MMRResultDiversification(); + results = diversification.diversify(results, diversificationContext); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + return results; + } + + @Override + public String getName() { + return NAME; + } + + public static ResultDiversificationRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) + throws IOException { + try { + return PARSER.apply(parser, context); + } catch (Exception e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever()); + builder.field(TYPE_FIELD.getPreferredName(), diversificationType); + builder.field(FIELD_FIELD.getPreferredName(), diversificationField); + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); + + if (queryVector != null) { + builder.array(QUERY_FIELD.getPreferredName(), queryVector); + } + + if (diversificationType.equals(DIVERSIFICATION_TYPE_MMR)) { + if (lambda != null) { + builder.field(LAMBDA_FIELD.getPreferredName(), lambda); + } + } + } + + @Override + protected RankDoc createRankDocFromHit(int docId, SearchHit hit, int shardRequestIndex) { + return new RankDocWithSearchHit(docId, hit.getScore(), shardRequestIndex, hit); + } + + @Override + public boolean doEquals(Object o) { + if (super.doEquals(o) == false) { + return false; + } + + if ((o instanceof ResultDiversificationRetrieverBuilder) == false) { + return false; + } + + ResultDiversificationRetrieverBuilder other = (ResultDiversificationRetrieverBuilder) o; + return this.diversificationType.equals(other.diversificationType) + && this.diversificationField.equals(other.diversificationField) + && Objects.equals(this.lambda, other.lambda) + && Arrays.equals(this.queryVector, other.queryVector); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java new file mode 100644 index 0000000000000..6b99a4a97f9bf --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversification.java @@ -0,0 +1,183 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification.mmr; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.search.diversification.ResultDiversification; +import org.elasticsearch.search.diversification.ResultDiversificationContext; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.VectorData; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MMRResultDiversification extends ResultDiversification { + + @Override + public RankDoc[] diversify(RankDoc[] docs, ResultDiversificationContext diversificationContext) throws IOException { + if (docs == null || docs.length == 0 || ((diversificationContext instanceof MMRResultDiversificationContext) == false)) { + return docs; + } + + MMRResultDiversificationContext context = (MMRResultDiversificationContext) diversificationContext; + + Map docIdIndexMapping = new HashMap<>(); + for (int i = 0; i < docs.length; i++) { + docIdIndexMapping.put(docs[i].doc, i); + } + + VectorSimilarityFunction similarityFunction = DenseVectorFieldMapper.VectorSimilarity.MAX_INNER_PRODUCT.vectorSimilarityFunction( + context.getIndexVersion(), + diversificationContext.getElementType() + ); + + // our chosen DocIDs to keep + List selectedDocIds = new ArrayList<>(); + + // always add the highest scoring doc to the list + int highestScoreDocId = -1; + float highestScore = Float.MIN_VALUE; + for (ScoreDoc doc : docs) { + if (doc.score > highestScore) { + highestScoreDocId = doc.doc; + highestScore = doc.score; + } + } + selectedDocIds.add(highestScoreDocId); + + // test the vector to see if we are using floats or bytes + VectorData firstVec = context.getFieldVector(highestScoreDocId); + boolean useFloat = firstVec.isFloat(); + + // cache the similarity scores for the query vector vs. searchHits + Map querySimilarity = getQuerySimilarityForDocs(docs, similarityFunction, useFloat, context); + + Map> cachedSimilarities = new HashMap<>(); + int numCandidates = context.getNumCandidates(); + + for (int x = 0; x < numCandidates && selectedDocIds.size() < numCandidates && selectedDocIds.size() < docs.length; x++) { + int thisMaxMMRDocId = -1; + float thisMaxMMRScore = Float.NEGATIVE_INFINITY; + for (ScoreDoc doc : docs) { + int docId = doc.doc; + + if (selectedDocIds.contains(docId)) { + continue; + } + + var thisDocVector = context.getFieldVector(docId); + if (thisDocVector == null) { + continue; + } + + var cachedScoresForDoc = cachedSimilarities.getOrDefault(docId, new HashMap<>()); + + // compute MMR scores for remaining searchHits + float highestMMRScore = getHighestScoreForSelectedVectors( + docId, + context, + similarityFunction, + useFloat, + thisDocVector, + cachedScoresForDoc + ); + + // compute MMR + float querySimilarityScore = querySimilarity.getOrDefault(doc.doc, 0.0f); + float mmr = (context.getLambda() * querySimilarityScore) - ((1 - context.getLambda()) * highestMMRScore); + if (mmr > thisMaxMMRScore) { + thisMaxMMRScore = mmr; + thisMaxMMRDocId = docId; + } + + // cache these scores + cachedSimilarities.put(docId, cachedScoresForDoc); + } + + if (thisMaxMMRDocId >= 0) { + selectedDocIds.add(thisMaxMMRDocId); + } + } + + // our return should be only those searchHits that are selected + // and return in the same order as we got them + List returnDocIndices = new ArrayList<>(); + for (Integer docId : selectedDocIds) { + returnDocIndices.add(docIdIndexMapping.get(docId)); + } + returnDocIndices.sort(Integer::compareTo); + + RankDoc[] ret = new RankDoc[returnDocIndices.size()]; + for (int i = 0; i < returnDocIndices.size(); i++) { + ret[i] = docs[returnDocIndices.get(i)]; + } + + return ret; + } + + private float getHighestScoreForSelectedVectors( + int docId, + MMRResultDiversificationContext context, + VectorSimilarityFunction similarityFunction, + boolean useFloat, + VectorData thisDocVector, + Map cachedScoresForDoc + ) { + float highestScore = Float.MIN_VALUE; + for (var vec : context.getFieldVectorsEntrySet()) { + if (vec.getKey().equals(docId)) { + continue; + } + + if (cachedScoresForDoc.containsKey(vec.getKey())) { + float score = cachedScoresForDoc.get(vec.getKey()); + if (score > highestScore) { + highestScore = score; + } + } else { + VectorData comparisonVector = vec.getValue(); + float score = getVectorComparisonScore(similarityFunction, useFloat, thisDocVector, comparisonVector); + cachedScoresForDoc.put(vec.getKey(), score); + if (score > highestScore) { + highestScore = score; + } + } + } + return highestScore; + } + + protected Map getQuerySimilarityForDocs( + ScoreDoc[] docs, + VectorSimilarityFunction similarityFunction, + boolean useFloat, + ResultDiversificationContext context + ) { + Map querySimilarity = new HashMap<>(); + + VectorData queryVector = context.getQueryVector(); + if (queryVector == null) { + return querySimilarity; + } + + for (ScoreDoc doc : docs) { + VectorData vectorData = context.getFieldVector(doc.doc); + if (vectorData != null) { + float querySimilarityScore = getVectorComparisonScore(similarityFunction, useFloat, vectorData, queryVector); + querySimilarity.put(doc.doc, querySimilarityScore); + } + } + return querySimilarity; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java new file mode 100644 index 0000000000000..e4b1623a1e1a1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationContext.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification.mmr; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.search.diversification.ResultDiversificationContext; +import org.elasticsearch.search.vectors.VectorData; + +import java.util.Map; + +public class MMRResultDiversificationContext extends ResultDiversificationContext { + + private final float lambda; + + public MMRResultDiversificationContext( + String field, + float lambda, + int numCandidates, + DenseVectorFieldMapper fieldMapper, + IndexVersion indexVersion, + @Nullable VectorData queryVector, + @Nullable Map fieldVectors + ) { + super(field, numCandidates, fieldMapper, indexVersion, queryVector, fieldVectors); + this.lambda = lambda; + } + + public float getLambda() { + return lambda; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 8042be444292d..661435ca7efe9 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.StoredFieldsContext; @@ -341,6 +342,18 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { return this; } + /** + * Overridable method to create the rank doc for the result set. + * + * @param docId the decoded docId + * @param hit the SearchHit object + * @param shardRequestIndex the shared request index + * @return a RankDoc (or subclass) + */ + protected RankDoc createRankDocFromHit(int docId, SearchHit hit, int shardRequestIndex) { + return new RankDoc(docId, hit.getScore(), shardRequestIndex); + } + private RankDoc[] getRankDocs(SearchResponse searchResponse) { int size = searchResponse.getHits().getHits().length; RankDoc[] docs = new RankDoc[size]; @@ -349,7 +362,7 @@ private RankDoc[] getRankDocs(SearchResponse searchResponse) { long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; int doc = ShardDocSortField.decodeDoc(sortValue); int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); - docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex); + docs[i] = createRankDocFromHit(doc, hit, shardRequestIndex); docs[i].rank = i + 1; } return docs; diff --git a/server/src/test/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilderParsingTests.java new file mode 100644 index 0000000000000..2d1eb5a2e0158 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilderParsingTests.java @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Collections.emptyList; + +public class ResultDiversificationRetrieverBuilderParsingTests extends AbstractXContentTestCase { + private static List xContentRegistryEntries; + + @BeforeClass + public static void init() { + xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents(); + } + + @AfterClass + public static void afterClass() throws Exception { + xContentRegistryEntries = null; + } + + @Override + protected ResultDiversificationRetrieverBuilder createTestInstance() { + int rankWindowSize = randomIntBetween(1, 20); + float[] queryVector = randomBoolean() ? getRandomQueryVector() : null; + Float lambda = randomBoolean() ? randomFloatBetween(0.0f, 1.0f, true) : null; + CompoundRetrieverBuilder.RetrieverSource innerRetriever = new CompoundRetrieverBuilder.RetrieverSource( + TestRetrieverBuilder.createRandomTestRetrieverBuilder(), + null + ); + return new ResultDiversificationRetrieverBuilder(innerRetriever, "mmr", "test_field", rankWindowSize, queryVector, lambda); + } + + @Override + protected ResultDiversificationRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (ResultDiversificationRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), n -> true) + ); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List entries = new ArrayList<>(xContentRegistryEntries); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + TestRetrieverBuilder.TEST_SPEC.getName(), + (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c), + TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion() + ) + ); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + new ParseField(ResultDiversificationRetrieverBuilder.NAME), + (p, c) -> ResultDiversificationRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) + ) + ); + return new NamedXContentRegistry(entries); + } + + private float[] getRandomQueryVector() { + float[] queryVector = new float[randomIntBetween(5, 256)]; + for (int i = 0; i < queryVector.length; i++) { + queryVector[i] = randomFloatBetween(0.0f, 1.0f, true); + } + return queryVector; + } +} diff --git a/server/src/test/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilderTests.java new file mode 100644 index 0000000000000..efe80ba1ae4b6 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/diversification/ResultDiversificationRetrieverBuilderTests.java @@ -0,0 +1,366 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.MockResolvedIndices; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.ResolvedIndices; +import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.Mapping; +import org.elasticsearch.index.mapper.MappingLookup; +import org.elasticsearch.index.mapper.MetadataFieldMapper; +import org.elasticsearch.index.mapper.RootObjectMapper; +import org.elasticsearch.index.mapper.SourceFieldMapper; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.PointInTimeBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.RemoteClusterAware; +import org.junit.Assert; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Predicate; + +public class ResultDiversificationRetrieverBuilderTests extends ESTestCase { + + public void testValidate() { + SearchSourceBuilder source = new SearchSourceBuilder(); + + // ensure type is MMR + var notMmrRetriever = new ResultDiversificationRetrieverBuilder( + getInnerRetriever(), + "not_mmr", + "test_field", + 10, + getRandomQueryVector(), + 0.5f + ); + var validationNotMmr = notMmrRetriever.validate(source, null, false, false); + assertEquals(1, validationNotMmr.validationErrors().size()); + assertEquals("[diversify] diversification type must be set to [mmr]", validationNotMmr.validationErrors().getFirst()); + + // ensure lambda is within range and set + var retrieverHighLambda = new ResultDiversificationRetrieverBuilder( + getInnerRetriever(), + "mmr", + "test_field", + 10, + getRandomQueryVector(), + 2.0f + ); + var validationLambda = retrieverHighLambda.validate(source, null, false, false); + assertEquals(1, validationLambda.validationErrors().size()); + assertEquals( + "[diversify] MMR result diversification must have a [lambda] between 0.0 and 1.0", + validationLambda.validationErrors().getFirst() + ); + + var retrieverLowLambda = new ResultDiversificationRetrieverBuilder( + getInnerRetriever(), + "mmr", + "test_field", + 10, + getRandomQueryVector(), + -0.1f + ); + validationLambda = retrieverLowLambda.validate(source, null, false, false); + assertEquals(1, validationLambda.validationErrors().size()); + assertEquals( + "[diversify] MMR result diversification must have a [lambda] between 0.0 and 1.0", + validationLambda.validationErrors().getFirst() + ); + + var retrieverNullLambda = new ResultDiversificationRetrieverBuilder( + getInnerRetriever(), + "mmr", + "test_field", + 10, + getRandomQueryVector(), + null + ); + validationLambda = retrieverNullLambda.validate(source, null, false, false); + assertEquals(1, validationLambda.validationErrors().size()); + assertEquals( + "[diversify] MMR result diversification must have a [lambda] between 0.0 and 1.0", + validationLambda.validationErrors().getFirst() + ); + } + + public void testClone() { + var original = createRandomRetriever(); + var clonedWithSameRetriever = original.clone(List.of(original.innerRetrievers().getFirst()), null); + assertNotSame(original, clonedWithSameRetriever); + assertTrue(original.doEquals(clonedWithSameRetriever)); + + CompoundRetrieverBuilder.RetrieverSource newInnerRetriever = getInnerRetriever(); + var cloned = original.clone(List.of(newInnerRetriever), null); + assertNotSame(original, cloned); + assertFalse(original.doEquals(cloned)); + + // make sure we have to have one and only one new inner retriever + AssertionError exNoRetrievers = Assert.assertThrows(AssertionError.class, () -> original.clone(List.of(), null)); + assertEquals("ResultDiversificationRetrieverBuilder must have a single child retriever", exNoRetrievers.getMessage()); + + AssertionError exTooMany = Assert.assertThrows( + AssertionError.class, + () -> original.clone(List.of(newInnerRetriever, newInnerRetriever), null) + ); + assertEquals("ResultDiversificationRetrieverBuilder must have a single child retriever", exTooMany.getMessage()); + } + + public void testDoRewrite() { + var queryRewriteContext = getQueryRewriteContext(); + var original = createRandomRetriever("dense_vector_field", 256); + var rewritten = original.doRewrite(queryRewriteContext); + assertSame(original, rewritten); + assertCompoundRetriever(original, rewritten); + + // will assert that the rewrite happened without assertion errors + List docs = new ArrayList<>(); + docs.add(new ScoreDoc[] {}); + var result = original.combineInnerRetrieverResults(docs, false); + assertEquals(0, result.length); + } + + public void testMmrResultDiversification() { + var queryRewriteContext = getQueryRewriteContext(); + var retriever = new ResultDiversificationRetrieverBuilder( + getInnerRetriever(), + ResultDiversificationRetrieverBuilder.DIVERSIFICATION_TYPE_MMR, + "dense_vector_field", + 3, + new float[] { 0.5f, 0.2f, 0.4f, 0.4f }, + 0.3f + ); + + // run the rewrite to set the internal diversification context + retriever.doRewrite(queryRewriteContext); + + List docs = new ArrayList<>(); + ScoreDoc[] hits = getTestSearchHits(); + docs.add(hits); + + var result = retriever.combineInnerRetrieverResults(docs, false); + + assertEquals(3, result.length); + assertEquals(1, result[0].doc); + assertEquals(3, result[1].doc); + assertEquals(6, result[2].doc); + + var retrieverWithoutRewrite = new ResultDiversificationRetrieverBuilder( + getInnerRetriever(), + ResultDiversificationRetrieverBuilder.DIVERSIFICATION_TYPE_MMR, + "dense_vector_field", + 3, + new float[] { 0.5f, 0.2f, 0.4f, 0.4f }, + 0.3f + ); + + AssertionError exNoRewrite = assertThrows( + AssertionError.class, + () -> retrieverWithoutRewrite.combineInnerRetrieverResults(docs, false) + ); + assertEquals("diversificationContext should be set before combining results", exNoRewrite.getMessage()); + + retrieverWithoutRewrite.doRewrite(queryRewriteContext); + + List nonProperDocs = new ArrayList<>(); + nonProperDocs.add(new ScoreDoc[] { new ScoreDoc(0, 0) }); + AssertionError exRankDocWithSearchHit = assertThrows( + AssertionError.class, + () -> retrieverWithoutRewrite.combineInnerRetrieverResults(nonProperDocs, false) + ); + assertEquals("expected results to be of type RankDocWithSearchHit", exRankDocWithSearchHit.getMessage()); + + AssertionError exNoDocs = assertThrows( + AssertionError.class, + () -> retrieverWithoutRewrite.combineInnerRetrieverResults(List.of(), false) + ); + assertEquals("ResultDiversificationRetrieverBuilder must have a single result set", exNoDocs.getMessage()); + + docs.add(hits); + AssertionError exMultipleDocs = assertThrows( + AssertionError.class, + () -> retrieverWithoutRewrite.combineInnerRetrieverResults(docs, false) + ); + assertEquals("ResultDiversificationRetrieverBuilder must have a single result set", exNoDocs.getMessage()); + } + + private ScoreDoc[] getTestSearchHits() { + return new ResultDiversificationRetrieverBuilder.RankDocWithSearchHit[] { + getTestSearchHit(1, 2.0f, new float[] { 0.4f, 0.2f, 0.4f, 0.4f }), + getTestSearchHit(2, 1.8f, new float[] { 0.4f, 0.2f, 0.3f, 0.3f }), + getTestSearchHit(3, 1.8f, new float[] { 0.4f, 0.1f, 0.3f, 0.3f }), + getTestSearchHit(4, 1.0f, new float[] { 0.1f, 0.9f, 0.5f, 0.9f }), + getTestSearchHit(5, 0.8f, new float[] { 0.1f, 0.9f, 0.5f, 0.9f }), + getTestSearchHit(6, 0.8f, new float[] { 0.05f, 0.05f, 0.05f, 0.05f }) }; + } + + private ResultDiversificationRetrieverBuilder.RankDocWithSearchHit getTestSearchHit(int docId, float score, float[] value) { + SearchHit hit = new SearchHit(docId); + hit.setDocumentField(new DocumentField("dense_vector_field", List.of(value))); + return new ResultDiversificationRetrieverBuilder.RankDocWithSearchHit(docId, score, 1, hit); + } + + protected void assertCompoundRetriever(ResultDiversificationRetrieverBuilder originalRetriever, RetrieverBuilder rewrittenRetriever) { + assertTrue(rewrittenRetriever instanceof ResultDiversificationRetrieverBuilder); + ResultDiversificationRetrieverBuilder actualRetrieverBuilder = (ResultDiversificationRetrieverBuilder) rewrittenRetriever; + assertEquals(originalRetriever.rankWindowSize(), actualRetrieverBuilder.rankWindowSize()); + } + + private static ResultDiversificationRetrieverBuilder createRandomRetriever() { + return createRandomRetriever(null, null); + } + + private static ResultDiversificationRetrieverBuilder createRandomRetriever( + @Nullable String fieldName, + @Nullable Integer vectorDimensions + ) { + String field = fieldName == null ? "test_field" : fieldName; + int rankWindowSize = randomIntBetween(1, 20); + float[] queryVector = vectorDimensions == null + ? randomBoolean() ? getRandomQueryVector() : null + : getRandomQueryVector(vectorDimensions); + Float lambda = randomFloatBetween(0.0f, 1.0f, true); + CompoundRetrieverBuilder.RetrieverSource innerRetriever = getInnerRetriever(); + return new ResultDiversificationRetrieverBuilder( + innerRetriever, + ResultDiversificationRetrieverBuilder.DIVERSIFICATION_TYPE_MMR, + field, + rankWindowSize, + queryVector, + lambda + ); + } + + private static CompoundRetrieverBuilder.RetrieverSource getInnerRetriever() { + return new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null); + } + + private static float[] getRandomQueryVector() { + return getRandomQueryVector(null); + } + + private static float[] getRandomQueryVector(@Nullable Integer vectorDimensions) { + int vectorSize = vectorDimensions == null ? randomIntBetween(5, 256) : vectorDimensions; + float[] queryVector = new float[vectorSize]; + for (int i = 0; i < queryVector.length; i++) { + queryVector[i] = randomFloatBetween(0.0f, 1.0f, true); + } + return queryVector; + } + + private static ResolvedIndices createMockResolvedIndices(Map> localIndexDenseVectorFields) { + Map indexMetadata = new HashMap<>(); + + for (var indexEntry : localIndexDenseVectorFields.entrySet()) { + String indexName = indexEntry.getKey(); + List denseVectorFields = indexEntry.getValue(); + + Index index = new Index(indexName, randomAlphaOfLength(10)); + + IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(index.getName()) + .settings( + Settings.builder() + .put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()) + .put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID()) + ) + .numberOfShards(1) + .numberOfReplicas(0); + + List denseVectorFieldsList = new ArrayList<>(); + for (String denseVectorField : denseVectorFields) { + denseVectorFieldsList.add( + String.format(Locale.ROOT, "\"%s\": { \"type\": \"dense_vector\", \"dims\": 256 }", denseVectorField) + ); + } + String mapping = String.format(Locale.ROOT, "{ \"properties\": {%s}}", String.join(",", denseVectorFieldsList)); + indexMetadataBuilder.putMapping(mapping); + indexMetadata.put(index, indexMetadataBuilder.build()); + } + + Map remoteIndices = new HashMap<>(); + return new MockResolvedIndices( + remoteIndices, + new OriginalIndices(localIndexDenseVectorFields.keySet().toArray(new String[0]), IndicesOptions.DEFAULT), + indexMetadata + ); + } + + private QueryRewriteContext getQueryRewriteContext() { + final String indexName = "test-index"; + final List testDenseVectorFields = List.of("dense_vector_field"); + final ResolvedIndices resolvedIndices = createMockResolvedIndices(Map.of(indexName, testDenseVectorFields)); + final Index localIndex = resolvedIndices.getConcreteLocalIndices()[0]; + final Predicate nameMatcher = testDenseVectorFields::contains; + final MappingLookup mappingLookup = MappingLookup.fromMapping(getTestMapping()); + + var indexMetadata = IndexMetadata.builder("index") + .settings( + indexSettings(IndexVersion.current(), 1, 1).put( + Settings.builder().put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current()).build() + ) + ) + .build(); + + return new QueryRewriteContext( + parserConfig(), + null, + null, + null, + mappingLookup, + Collections.emptyMap(), + new IndexSettings(indexMetadata, Settings.EMPTY), + TransportVersion.current(), + RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, + localIndex, + nameMatcher, + null, + null, + () -> false, + null, + resolvedIndices, + new PointInTimeBuilder(new BytesArray("pitid")), + null, + null, + false + ); + } + + private Mapping getTestMapping() { + SourceFieldMapper sourceMapper = new SourceFieldMapper.Builder(null, Settings.EMPTY, false, false, false).setSynthetic().build(); + RootObjectMapper root = new RootObjectMapper.Builder("_doc").add( + new DenseVectorFieldMapper.Builder("dense_vector_field", IndexVersion.current(), false, List.of()) + ).build(MapperBuilderContext.root(true, false)); + + return new Mapping(root, new MetadataFieldMapper[] { sourceMapper }, Map.of()); + } +} diff --git a/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java new file mode 100644 index 0000000000000..03f6fe6f09f60 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/diversification/mmr/MMRResultDiversificationTests.java @@ -0,0 +1,111 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.diversification.mmr; + +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MMRResultDiversificationTests extends ESTestCase { + + public void testMMRDiversification() throws IOException { + final MapperBuilderContext context = MapperBuilderContext.root(false, false); + + DenseVectorFieldMapper mapper = new DenseVectorFieldMapper.Builder("dense_vector_field", IndexVersion.current(), false, List.of()) + .elementType(DenseVectorFieldMapper.ElementType.FLOAT) + .dimensions(4) + .build(context); + + DenseVectorFieldMapper.Builder builder = (DenseVectorFieldMapper.Builder) mapper.getMergeBuilder(); + builder.elementType(DenseVectorFieldMapper.ElementType.FLOAT); + DenseVectorFieldMapper fieldMapper = builder.build(context); + + var queryVectorData = new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); + Map fieldVectors = Map.of( + 1, + new VectorData(new float[] { 0.4f, 0.2f, 0.4f, 0.4f }), + 2, + new VectorData(new float[] { 0.4f, 0.2f, 0.3f, 0.3f }), + 3, + new VectorData(new float[] { 0.4f, 0.1f, 0.3f, 0.3f }), + 4, + new VectorData(new float[] { 0.1f, 0.9f, 0.5f, 0.9f }), + 5, + new VectorData(new float[] { 0.1f, 0.9f, 0.5f, 0.9f }), + 6, + new VectorData(new float[] { 0.05f, 0.05f, 0.05f, 0.05f }) + ); + var diversificationContext = new MMRResultDiversificationContext( + "dense_vector_field", + 0.3f, + 3, + fieldMapper, + IndexVersion.current(), + queryVectorData, + fieldVectors + ); + + RankDoc[] docs = new RankDoc[] { + new RankDoc(1, 2.0f, 1), + new RankDoc(2, 1.8f, 1), + new RankDoc(3, 1.8f, 1), + new RankDoc(4, 1.0f, 1), + new RankDoc(5, 0.8f, 1), + new RankDoc(6, 0.8f, 1) }; + + MMRResultDiversification resultDiversification = new MMRResultDiversification(); + RankDoc[] diversifiedTopDocs = resultDiversification.diversify(docs, diversificationContext); + assertNotSame(docs, diversifiedTopDocs); + + assertEquals(3, diversifiedTopDocs.length); + assertEquals(1, diversifiedTopDocs[0].doc); + assertEquals(3, diversifiedTopDocs[1].doc); + assertEquals(6, diversifiedTopDocs[2].doc); + } + + public void testMMRDiversificationIfNoSearchHits() throws IOException { + final MapperBuilderContext context = MapperBuilderContext.root(false, false); + + DenseVectorFieldMapper mapper = new DenseVectorFieldMapper.Builder("dense_vector_field", IndexVersion.current(), false, List.of()) + .elementType(DenseVectorFieldMapper.ElementType.FLOAT) + .dimensions(4) + .build(context); + + // Change the element type to byte, which is incompatible with int8 HNSW index options + DenseVectorFieldMapper.Builder builder = (DenseVectorFieldMapper.Builder) mapper.getMergeBuilder(); + builder.elementType(DenseVectorFieldMapper.ElementType.FLOAT); + DenseVectorFieldMapper fieldMapper = builder.build(context); + + var queryVectorData = new VectorData(new float[] { 0.5f, 0.2f, 0.4f, 0.4f }); + var diversificationContext = new MMRResultDiversificationContext( + "dense_vector_field", + 0.6f, + 10, + fieldMapper, + IndexVersion.current(), + queryVectorData, + new HashMap<>() + ); + RankDoc[] emptyDocs = new RankDoc[0]; + + MMRResultDiversification resultDiversification = new MMRResultDiversification(); + + assertSame(emptyDocs, resultDiversification.diversify(emptyDocs, diversificationContext)); + assertNull(resultDiversification.diversify(null, diversificationContext)); + } +}