diff --git a/docs/changelog/124182.yaml b/docs/changelog/124182.yaml new file mode 100644 index 0000000000000..27c36e96ecd9b --- /dev/null +++ b/docs/changelog/124182.yaml @@ -0,0 +1,5 @@ +pr: 124182 +summary: Add `min_score` support to linear retriever +area: Search +type: enhancement +issues: [] diff --git a/docs/reference/elasticsearch/rest-apis/retrievers.md b/docs/reference/elasticsearch/rest-apis/retrievers.md index 14c413a7832ed..1b157b090537e 100644 --- a/docs/reference/elasticsearch/rest-apis/retrievers.md +++ b/docs/reference/elasticsearch/rest-apis/retrievers.md @@ -269,11 +269,11 @@ Each entry specifies the following parameters: * `weight`:: (Optional, float) - The weight that each score of this retriever’s top docs will be multiplied with. Must be greater or equal to 0. Defaults to 1.0. + The weight that each score of this retriever's top docs will be multiplied with. Must be greater or equal to 0. Defaults to 1.0. * `normalizer`:: (Optional, String) - Specifies how we will normalize the retriever’s scores, before applying the specified `weight`. Available values are: `minmax`, and `none`. Defaults to `none`. + Specifies how we will normalize the retriever's scores, before applying the specified `weight`. Available values are: `minmax`, and `none`. Defaults to `none`. * `none` * `minmax` : A `MinMaxScoreNormalizer` that normalizes scores based on the following formula @@ -288,14 +288,78 @@ See also [this hybrid search example](docs-content://solutions/search/retrievers `rank_window_size` : (Optional, integer) - This value determines the size of the individual result sets per query. A higher value will improve result relevance at the cost of performance. The final ranked result set is pruned down to the search request’s [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. Defaults to the `size` parameter. 
+ This value determines the size of the individual result sets per query. A higher value will improve result relevance at the cost of performance. The final ranked result set is pruned down to the search request's [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. Defaults to the `size` parameter. + + +`min_score` +: (Optional, float) + + Minimum score threshold for documents to be included in the final result set. Documents with scores below this threshold will be filtered out. Must be greater than or equal to 0. Defaults to 0. `filter` : (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - Applies the specified [boolean query filter](/reference/query-languages/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever’s specifications. + Applies the specified [boolean query filter](/reference/query-languages/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever's specifications. + + +### Example: Hybrid search with min_score [linear-retriever-example] + +This example demonstrates how to use the Linear retriever to combine a standard retriever with a kNN retriever, applying weights, normalization, and a minimum score threshold: + +```console +GET /restaurants/_search +{ + "retriever": { + "linear": { <1> + "retrievers": [ <2> + { + "retriever": { <3> + "standard": { + "query": { + "multi_match": { + "query": "Italian cuisine", + "fields": [ + "description", + "cuisine" + ] + } + } + } + }, + "weight": 2.0, <4> + "normalizer": "minmax" <5> + }, + { + "retriever": { <6> + "knn": { + "field": "vector", + "query_vector": [10, 22, 77], + "k": 10, + "num_candidates": 10 + } + }, + "weight": 1.0, <7> + "normalizer": "minmax" <8> + } + ], + "rank_window_size": 50, <9> + "min_score": 1.5 <10> + } + } +} +``` +1. 
Defines a retriever tree with a Linear retriever. +2. The sub-retrievers array. +3. The first sub-retriever is a `standard` retriever. +4. The weight applied to the scores from the standard retriever (2.0). +5. The normalization method applied to the standard retriever's scores. +6. The second sub-retriever is a `knn` retriever. +7. The weight applied to the scores from the kNN retriever (1.0). +8. The normalization method applied to the kNN retriever's scores. +9. The rank window size for the Linear retriever. +10. The minimum score threshold - documents with a combined score below 1.5 will be filtered out from the final result set. ## RRF Retriever [rrf-retriever] @@ -320,13 +384,13 @@ An [RRF](/reference/elasticsearch/rest-apis/reciprocal-rank-fusion.md) retriever `rank_window_size` : (Optional, integer) - This value determines the size of the individual result sets per query. A higher value will improve result relevance at the cost of performance. The final ranked result set is pruned down to the search request’s [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. Defaults to the `size` parameter. + This value determines the size of the individual result sets per query. A higher value will improve result relevance at the cost of performance. The final ranked result set is pruned down to the search request's [size](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search#search-size-param). `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. Defaults to the `size` parameter. `filter` : (Optional, [query object or list of query objects](/reference/query-languages/querydsl.md)) - Applies the specified [boolean query filter](/reference/query-languages/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever’s specifications. 
+ Applies the specified [boolean query filter](/reference/query-languages/query-dsl-bool-query.md) to all of the specified sub-retrievers, according to each retriever's specifications. @@ -435,12 +499,12 @@ For compound retrievers like `rrf`, the `window_size` parameter defines the tota When using the `rescorer`, an error is returned if the following conditions are not met: -* The minimum configured rescore’s `window_size` is: +* The minimum configured rescore's `window_size` is: * Greater than or equal to the `size` of the parent retriever for nested `rescorer` setups. * Greater than or equal to the `size` of the search request when used as the primary retriever in the tree. -* And the maximum rescore’s `window_size` is: +* And the maximum rescore's `window_size` is: * Smaller than or equal to the `size` or `rank_window_size` of the child retriever. @@ -564,7 +628,7 @@ To use `text_similarity_reranker` you must first set up an inference endpoint fo You have the following options: -* Use the the built-in [Elastic Rerank](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) cross-encoder model via the inference API’s {{es}} service. +* Use the built-in [Elastic Rerank](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) cross-encoder model via the inference API's {{es}} service. * Use the [Cohere Rerank inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. * Use the [Google Vertex AI inference endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) with the `rerank` task type. * Upload a model to {{es}} with [Eland](eland://reference/machine-learning.md#ml-nlp-pytorch) using the `text_similarity` NLP task type. 
diff --git a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java index 524310c547597..c7fe110f2a905 100644 --- a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java @@ -29,14 +29,36 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder c.readNamedWriteable(QueryBuilder.class), QueryBuilder[]::new); this.onlyRankDocs = in.readBoolean(); + this.minScore = in.readFloat(); + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_17_0)) { + this.countFilteredHits = in.readBoolean(); + } else { + this.countFilteredHits = false; + } } else { this.queryBuilders = null; this.onlyRankDocs = false; + this.minScore = DEFAULT_MIN_SCORE; + this.countFilteredHits = false; } } @@ -70,7 +100,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws changed |= newQueryBuilders[i] != queryBuilders[i]; } if (changed) { - RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs); + RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs, minScore); clone.queryName(queryName()); return clone; } @@ -88,6 +118,10 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalArray(StreamOutput::writeNamedWriteable, queryBuilders); out.writeBoolean(onlyRankDocs); + out.writeFloat(minScore); + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_17_0)) { + out.writeBoolean(countFilteredHits); + } } } @@ -115,7 +149,12 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { queries = new Query[0]; queryNames = Strings.EMPTY_ARRAY; } - return new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs); + + RankDocsQuery 
query = new RankDocsQuery(reader, shardRankDocs, queries, queryNames, onlyRankDocs, minScore); + if (countFilteredHits) { + query.setCountFilteredHits(true); + } + return query; } @Override @@ -135,16 +174,31 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep protected boolean doEquals(RankDocsQueryBuilder other) { return Arrays.equals(rankDocs, other.rankDocs) && Arrays.equals(queryBuilders, other.queryBuilders) - && onlyRankDocs == other.onlyRankDocs; + && onlyRankDocs == other.onlyRankDocs + && minScore == other.minScore + && countFilteredHits == other.countFilteredHits; } @Override protected int doHashCode() { - return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs); + return Objects.hash(Arrays.hashCode(rankDocs), Arrays.hashCode(queryBuilders), onlyRankDocs, minScore, countFilteredHits); } @Override public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_16_0; } + + /** + * Sets whether this query should count only documents that pass the min_score filter. + * When true, the total hits count will reflect the number of documents meeting the minimum score threshold. + * When false (default), the total hits count will include all matching documents regardless of score. 
+ * + * @param countFilteredHits true to count only documents passing min_score, false to count all matches + * @return this builder + */ + public RankDocsQueryBuilder setCountFilteredHits(boolean countFilteredHits) { + this.countFilteredHits = countFilteredHits; + return this; + } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 0bb5fd849bbcf..045781be73dea 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -198,6 +198,7 @@ public void onFailure(Exception e) { results::get ); rankDocsRetrieverBuilder.retrieverName(retrieverName()); + rankDocsRetrieverBuilder.minScore = minScore; return rankDocsRetrieverBuilder; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index a77f5327fbc26..81aeabef5e813 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -23,6 +23,8 @@ import java.util.Objects; import java.util.function.Supplier; +import static org.elasticsearch.index.query.RankDocsQueryBuilder.DEFAULT_MIN_SCORE; + /** * An {@link RetrieverBuilder} that is used to retrieve documents based on the rank of the documents. */ @@ -93,7 +95,8 @@ public QueryBuilder explainQuery() { var explainQuery = new RankDocsQueryBuilder( rankDocs.get(), sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new), - true + true, + this.minScore() != null ? 
this.minScore() : DEFAULT_MIN_SCORE ); explainQuery.queryName(retrieverName()); return explainQuery; @@ -105,38 +108,98 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder // if we have aggregations we need to compute them based on all doc matches, not just the top hits // similarly, for profile and explain we re-run all parent queries to get all needed information RankDoc[] rankDocResults = rankDocs.get(); + float effectiveMinScore = getEffectiveMinScore(); + + System.out.println( + "DEBUG: RankDocsRetrieverBuilder - extractToSearchSourceBuilder with " + + (rankDocResults != null ? rankDocResults.length : 0) + + " rank results" + ); + System.out.println("DEBUG: RankDocsRetrieverBuilder - minScore=" + minScore() + ", effective minScore=" + effectiveMinScore); + if (hasAggregations(searchSourceBuilder) || isExplainRequest(searchSourceBuilder) || isProfileRequest(searchSourceBuilder) || shouldTrackTotalHits(searchSourceBuilder)) { + System.out.println( + "DEBUG: RankDocsRetrieverBuilder - Building with explainQuery=" + + isExplainRequest(searchSourceBuilder) + + ", hasAggs=" + + hasAggregations(searchSourceBuilder) + + ", isProfile=" + + isProfileRequest(searchSourceBuilder) + + ", shouldTrackTotalHits=" + + shouldTrackTotalHits(searchSourceBuilder) + ); + if (false == isExplainRequest(searchSourceBuilder)) { rankQuery = new RankDocsQueryBuilder( rankDocResults, sources.stream().map(RetrieverBuilder::topDocsQuery).toArray(QueryBuilder[]::new), - false + false, + effectiveMinScore ); } else { rankQuery = new RankDocsQueryBuilder( rankDocResults, sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new), - false + false, + effectiveMinScore ); } } else { - rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false); + System.out.println("DEBUG: RankDocsRetrieverBuilder - Building with simplified query"); + rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false, effectiveMinScore); } + + 
System.out.println("DEBUG: RankDocsRetrieverBuilder - Created rankQuery with minScore=" + effectiveMinScore); rankQuery.queryName(retrieverName()); // ignore prefilters of this level, they were already propagated to children searchSourceBuilder.query(rankQuery); if (searchSourceBuilder.size() < 0) { searchSourceBuilder.size(rankWindowSize); } - if (sourceHasMinScore()) { - searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore()); + + // Set track total hits to equal the number of results, ensuring the correct count is returned + boolean emptyResults = rankDocResults.length == 0; + boolean shouldTrack = shouldTrackTotalHits(searchSourceBuilder); + + if (shouldTrack) { + int hitsToTrack = emptyResults ? Integer.MAX_VALUE : rankDocResults.length; + System.out.println("DEBUG: RankDocsRetrieverBuilder - Setting trackTotalHitsUpTo to " + hitsToTrack); + searchSourceBuilder.trackTotalHitsUpTo(hitsToTrack); + } + + // Always set minScore if it's meaningful (greater than default) + boolean hasSignificantMinScore = effectiveMinScore > DEFAULT_MIN_SCORE; + + System.out.println( + "DEBUG: RankDocsRetrieverBuilder - sourceHasMinScore=" + + sourceHasMinScore() + + ", effectiveMinScore=" + + effectiveMinScore + + ", hasSignificantMinScore=" + + hasSignificantMinScore + ); + + if (hasSignificantMinScore) { + // Set minScore on the search source builder - this ensures filtering happens + searchSourceBuilder.minScore(effectiveMinScore); } + if (searchSourceBuilder.size() + searchSourceBuilder.from() > rankDocResults.length) { searchSourceBuilder.size(Math.max(0, rankDocResults.length - searchSourceBuilder.from())); } + + System.out.println( + "DEBUG: RankDocsRetrieverBuilder - Final searchSourceBuilder: " + + "size=" + + searchSourceBuilder.size() + + ", minScore=" + + searchSourceBuilder.minScore() + + ", trackTotalHitsUpTo=" + + searchSourceBuilder.trackTotalHitsUpTo() + ); } private boolean hasAggregations(SearchSourceBuilder searchSourceBuilder) 
{ @@ -152,7 +215,55 @@ private boolean isProfileRequest(SearchSourceBuilder searchSourceBuilder) { } private boolean shouldTrackTotalHits(SearchSourceBuilder searchSourceBuilder) { - return searchSourceBuilder.trackTotalHitsUpTo() == null || searchSourceBuilder.trackTotalHitsUpTo() > rankDocs.get().length; + // Always track total hits if minScore is being used, since we need to maintain the filtered count + if (minScore() != null && minScore() > DEFAULT_MIN_SCORE) { + return true; + } + + // Check sources for minScore - if any have a significant minScore, we need to track hits + for (RetrieverBuilder source : sources) { + Float sourceMinScore = source.minScore(); + if (sourceMinScore != null && sourceMinScore > DEFAULT_MIN_SCORE) { + return true; + } + } + + // Otherwise use default behavior + return searchSourceBuilder.trackTotalHitsUpTo() == null + || (rankDocs.get() != null && searchSourceBuilder.trackTotalHitsUpTo() > rankDocs.get().length); + } + + /** + * Gets the effective minimum score, either from this builder or from one of its sources. + * If no minimum score is set, returns the default minimum score. 
+ */ + private float getEffectiveMinScore() { + System.out.println("DEBUG: RankDocsRetrieverBuilder.getEffectiveMinScore() - this.minScore=" + minScore); + + if (minScore != null) { + System.out.println("DEBUG: RankDocsRetrieverBuilder.getEffectiveMinScore() - using this.minScore=" + minScore); + return minScore; + } + + // Check if any of the sources have a minScore + System.out.println("DEBUG: RankDocsRetrieverBuilder.getEffectiveMinScore() - checking " + sources.size() + " sources"); + for (RetrieverBuilder source : sources) { + Float sourceMinScore = source.minScore(); + System.out.println( + "DEBUG: RankDocsRetrieverBuilder.getEffectiveMinScore() - source minScore=" + + sourceMinScore + + " for source " + + source.getClass().getSimpleName() + ); + + if (sourceMinScore != null && sourceMinScore > DEFAULT_MIN_SCORE) { + System.out.println("DEBUG: RankDocsRetrieverBuilder.getEffectiveMinScore() - using source minScore=" + sourceMinScore); + return sourceMinScore; + } + } + + System.out.println("DEBUG: RankDocsRetrieverBuilder.getEffectiveMinScore() - using DEFAULT_MIN_SCORE=" + DEFAULT_MIN_SCORE); + return DEFAULT_MIN_SCORE; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java index 5920567646030..ed4fd3092d353 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java @@ -24,6 +24,7 @@ import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import org.elasticsearch.common.lucene.search.function.MinScoreScorer; import org.elasticsearch.search.rank.RankDoc; import java.io.IOException; @@ -32,6 +33,7 @@ import java.util.Objects; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static 
org.elasticsearch.index.query.RankDocsQueryBuilder.DEFAULT_MIN_SCORE; /** * A {@code RankDocsQuery} returns the top k documents in the order specified by the global doc IDs. @@ -49,14 +51,16 @@ public static class TopQuery extends Query { private final String[] queryNames; private final int[] segmentStarts; private final Object contextIdentity; + private final float minScore; - TopQuery(RankDoc[] docs, Query[] sources, String[] queryNames, int[] segmentStarts, Object contextIdentity) { + TopQuery(RankDoc[] docs, Query[] sources, String[] queryNames, int[] segmentStarts, Object contextIdentity, float minScore) { assert sources.length == queryNames.length; this.docs = docs; this.sources = sources; this.queryNames = queryNames; this.segmentStarts = segmentStarts; this.contextIdentity = contextIdentity; + this.minScore = minScore; for (RankDoc doc : docs) { if (false == doc.score >= 0) { throw new IllegalArgumentException("RankDoc scores must be positive values. Missing a normalization step?"); @@ -76,7 +80,7 @@ public Query rewrite(IndexSearcher searcher) throws IOException { changed |= newSources[i] != sources[i]; } if (changed) { - return new TopQuery(docs, newSources, queryNames, segmentStarts, contextIdentity); + return new TopQuery(docs, newSources, queryNames, segmentStarts, contextIdentity, minScore); } return this; } @@ -164,12 +168,12 @@ public float getMaxScore(int docId) { } @Override - public float score() { - // We could still end up with a valid 0 score for a RankDoc - // so here we want to differentiate between this and all the tailQuery matches - // that would also produce a 0 score due to filtering, by setting the score to `Float.MIN_VALUE` instead for - // RankDoc matches. 
- return Math.max(docs[upTo].score, Float.MIN_VALUE); + public float score() throws IOException { + // We need to handle scores of exactly 0 specially: + // Even when a document legitimately has a score of 0, we replace it with DEFAULT_MIN_SCORE + // to differentiate it from filtered tailQuery matches that would also produce a 0 score. + float docScore = docs[upTo].score; + return docScore == 0 ? DEFAULT_MIN_SCORE : docScore; } @Override @@ -234,6 +238,21 @@ public int hashCode() { // RankDocs provided. This query does not contribute to scoring, as it is set as filter when creating the weight private final Query tailQuery; private final boolean onlyRankDocs; + private final float minScore; + private boolean countFilteredHits = false; + + /** + * Sets whether this query should count only documents that pass the min_score filter. + * When true, the total hits count will reflect the number of documents meeting the minimum score threshold. + * When false (default), the total hits count will include all matching documents regardless of score. + * + * @param countFilteredHits true to count only documents passing min_score, false to count all matches + * @return this query + */ + public RankDocsQuery setCountFilteredHits(boolean countFilteredHits) { + this.countFilteredHits = countFilteredHits; + return this; + } /** * Creates a {@code RankDocsQuery} based on the provided docs. @@ -242,14 +261,26 @@ public int hashCode() { * @param sources The original queries that were used to compute the top documents * @param queryNames The names (if present) of the original retrievers * @param onlyRankDocs Whether the query should only match the provided rank docs + * @param minScore The minimum score threshold for documents to be included in total hits. + * This can be set to any value including 0.0f to filter out documents with scores below the threshold. + * Note: This is separate from the special handling of 0 scores.
Documents with a score of exactly 0 + * will always be assigned DEFAULT_MIN_SCORE internally to differentiate them from filtered matches, + * regardless of the minScore value. */ - public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, String[] queryNames, boolean onlyRankDocs) { + public RankDocsQuery( + IndexReader reader, + RankDoc[] rankDocs, + Query[] sources, + String[] queryNames, + boolean onlyRankDocs, + float minScore + ) { assert sources.length == queryNames.length; // clone to avoid side-effect after sorting this.docs = rankDocs.clone(); // sort rank docs by doc id Arrays.sort(docs, Comparator.comparingInt(a -> a.doc)); - this.topQuery = new TopQuery(docs, sources, queryNames, findSegmentStarts(reader, docs), reader.getContext().id()); + this.topQuery = new TopQuery(docs, sources, queryNames, findSegmentStarts(reader, docs), reader.getContext().id(), minScore); if (sources.length > 0 && false == onlyRankDocs) { var bq = new BooleanQuery.Builder(); for (var source : sources) { @@ -260,13 +291,17 @@ public RankDocsQuery(IndexReader reader, RankDoc[] rankDocs, Query[] sources, St this.tailQuery = null; } this.onlyRankDocs = onlyRankDocs; + this.minScore = minScore; + this.countFilteredHits = false; } - private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs) { + private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs, float minScore) { this.docs = docs; this.topQuery = topQuery; this.tailQuery = tailQuery; this.onlyRankDocs = onlyRankDocs; + this.minScore = minScore; + this.countFilteredHits = false; } private static int binarySearch(RankDoc[] docs, int fromIndex, int toIndex, int key) { @@ -310,7 +345,11 @@ public Query rewrite(IndexSearcher searcher) throws IOException { if (tailRewrite != tailQuery) { hasChanged = true; } - return hasChanged ? new RankDocsQuery(docs, topRewrite, tailRewrite, onlyRankDocs) : this; + RankDocsQuery rewritten = hasChanged ? 
new RankDocsQuery(docs, topRewrite, tailRewrite, onlyRankDocs, minScore) : this; + if (hasChanged && countFilteredHits) { + rewritten.setCountFilteredHits(true); + } + return rewritten; } @Override @@ -326,7 +365,29 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo return new Weight(this) { @Override public int count(LeafReaderContext context) throws IOException { - return combinedWeight.count(context); + // When minScore is set to a value higher than DEFAULT_MIN_SCORE, filter docs by score + if ((minScore > DEFAULT_MIN_SCORE) && countFilteredHits) { + System.out.println( + "DEBUG: RankDocsQuery - count with minScore=" + + minScore + + " > DEFAULT_MIN_SCORE=" + + DEFAULT_MIN_SCORE + + ", countFilteredHits=" + + countFilteredHits + ); + int count = 0; + for (RankDoc doc : docs) { + if (doc.score >= minScore && doc.doc >= context.docBase && doc.doc < context.docBase + context.reader().maxDoc()) { + count++; + } + } + System.out.println("DEBUG: RankDocsQuery - filtered count=" + count + " from " + docs.length + " total docs"); + return count; + } + + int combinedCount = combinedWeight.count(context); + System.out.println("DEBUG: RankDocsQuery - using combined weight count=" + combinedCount); + return combinedCount; } @Override @@ -346,7 +407,24 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { - return combinedWeight.scorerSupplier(context); + ScorerSupplier supplier = combinedWeight.scorerSupplier(context); + if (supplier == null) { + return null; + } + if (minScore <= DEFAULT_MIN_SCORE) { + return supplier; + } + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + return new MinScoreScorer(supplier.get(leadCost), minScore); + } + + @Override + public long cost() { + return supplier.cost(); + } + }; } }; } @@ -370,11 +448,15 @@ public boolean equals(Object obj) 
{ return false; } RankDocsQuery other = (RankDocsQuery) obj; - return Objects.equals(topQuery, other.topQuery) && Objects.equals(tailQuery, other.tailQuery) && onlyRankDocs == other.onlyRankDocs; + return Objects.equals(topQuery, other.topQuery) + && Objects.equals(tailQuery, other.tailQuery) + && onlyRankDocs == other.onlyRankDocs + && minScore == other.minScore + && countFilteredHits == other.countFilteredHits; } @Override public int hashCode() { - return Objects.hash(classHash(), topQuery, tailQuery, onlyRankDocs); + return Objects.hash(classHash(), topQuery, tailQuery, onlyRankDocs, minScore, countFilteredHits); } } diff --git a/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java index 9f1d2fbfdefff..3bab061fb7662 100644 --- a/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java @@ -17,6 +17,7 @@ import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopScoreDocCollectorManager; @@ -30,6 +31,7 @@ import java.util.Arrays; import java.util.Random; +import static org.elasticsearch.index.query.RankDocsQueryBuilder.DEFAULT_MIN_SCORE; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -50,7 +52,7 @@ private RankDoc[] generateRandomRankDocs() { @Override protected RankDocsQueryBuilder doCreateTestQueryBuilder() { RankDoc[] rankDocs = generateRandomRankDocs(); - return new RankDocsQueryBuilder(rankDocs, null, false); + return new RankDocsQueryBuilder(rankDocs, null, false, DEFAULT_MIN_SCORE); } @Override @@ -60,6 +62,23 @@ protected void 
doAssertLuceneQuery(RankDocsQueryBuilder queryBuilder, Query quer assertArrayEquals(queryBuilder.rankDocs(), rankDocsQuery.rankDocs()); } + protected Query createTestQuery() throws IOException { + return createRandomQuery().toQuery(createSearchExecutionContext()); + } + + private RankDocsQueryBuilder createQueryBuilder() { + return createRandomQuery(); + } + + private RankDocsQueryBuilder createRandomQuery() { + RankDoc[] rankDocs = new RankDoc[randomIntBetween(1, 5)]; + for (int i = 0; i < rankDocs.length; i++) { + rankDocs[i] = new RankDoc(randomInt(), randomFloat(), randomIntBetween(0, 2)); + } + float minScore = randomBoolean() ? DEFAULT_MIN_SCORE : randomFloat(); + return new RankDocsQueryBuilder(rankDocs, null, randomBoolean(), minScore); + } + /** * Overridden to ensure that {@link SearchExecutionContext} has a non-null {@link IndexReader} */ @@ -151,7 +170,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException { rankDocs, new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, new String[1], - false + false, + DEFAULT_MIN_SCORE ); var topDocsManager = new TopScoreDocCollectorManager(topSize, null, totalHitsThreshold); var col = searcher.search(q, topDocsManager); @@ -172,7 +192,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException { rankDocs, new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, new String[1], - false + false, + DEFAULT_MIN_SCORE ); var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE); var col = searcher.search(q, topDocsManager); @@ -187,7 +208,8 @@ public void testRankDocsQueryEarlyTerminate() throws IOException { rankDocs, new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, new String[1], - true + true, + DEFAULT_MIN_SCORE ); var topDocsManager = new TopScoreDocCollectorManager(topSize, null, Integer.MAX_VALUE); var col = searcher.search(q, topDocsManager); @@ -204,7 +226,8 @@ public void testRankDocsQueryEarlyTerminate() 
throws IOException { singleRankDoc, new Query[] { NumericDocValuesField.newSlowExactQuery("active", 1) }, new String[1], - false + false, + DEFAULT_MIN_SCORE ); var topDocsManager = new TopScoreDocCollectorManager(1, null, 0); var col = searcher.search(q, topDocsManager); @@ -257,10 +280,34 @@ public void shouldThrowForNegativeScores() throws IOException { iw.addDocument(new Document()); try (IndexReader reader = iw.getReader()) { SearchExecutionContext context = createSearchExecutionContext(newSearcher(reader)); - RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder(new RankDoc[] { new RankDoc(0, -1.0f, 0) }, null, false); + RankDocsQueryBuilder queryBuilder = new RankDocsQueryBuilder( + new RankDoc[] { new RankDoc(0, -1.0f, 0) }, + null, + false, + DEFAULT_MIN_SCORE + ); IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> queryBuilder.doToQuery(context)); assertEquals("RankDoc scores must be positive values. Missing a normalization step?", ex.getMessage()); } } } + + public void testCreateQuery() throws IOException { + try (Directory directory = newDirectory(); RandomIndexWriter iw = new RandomIndexWriter(random(), directory)) { + iw.addDocument(new Document()); + try (IndexReader reader = iw.getReader()) { + RankDoc[] rankDocs = new RankDoc[] { new RankDoc(0, randomFloat(), 0) }; + RankDocsQuery q = new RankDocsQuery( + reader, + rankDocs, + new Query[] { new MatchAllDocsQuery() }, + new String[] { "test" }, + false, + DEFAULT_MIN_SCORE + ); + assertNotNull(q); + assertArrayEquals(rankDocs, q.rankDocs()); + } + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java b/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java index 8cc40570ab4bb..19fa11d0d700c 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java +++ 
b/server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Set; +import static org.elasticsearch.index.query.RankDocsQueryBuilder.DEFAULT_MIN_SCORE; import static org.hamcrest.Matchers.equalTo; public abstract class AbstractRankDocWireSerializingTestCase extends AbstractWireSerializingTestCase { @@ -50,7 +51,12 @@ public void testRankDocSerialization() throws IOException { for (int i = 0; i < totalDocs; i++) { docs.add(createTestRankDoc()); } - RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean()); + RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder( + docs.toArray((T[]) new RankDoc[0]), + null, + randomBoolean(), + DEFAULT_MIN_SCORE + ); RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class); assertThat(rankDocsQueryBuilder, equalTo(copy)); } diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index f98231a647470..a8e3cf2142d1b 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -12,10 +12,19 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.ClosePointInTimeRequest; +import org.elasticsearch.action.search.ClosePointInTimeResponse; +import org.elasticsearch.action.search.OpenPointInTimeRequest; +import 
org.elasticsearch.action.search.OpenPointInTimeResponse; import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.TransportClosePointInTimeAction; +import org.elasticsearch.action.search.TransportOpenPointInTimeAction; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.InnerHitBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -23,6 +32,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; @@ -39,12 +49,15 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; +import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; @@ -53,6 +66,7 @@ import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.instanceOf; @ESIntegTestCase.ClusterScope(minNumDataNodes = 2) @@ -64,6 +78,8 @@ public class 
LinearRetrieverIT extends ESIntegTestCase { protected static final String VECTOR_FIELD = "vector"; protected static final String TOPIC_FIELD = "topic"; + private BytesReference pitId; + @Override protected Collection> nodePlugins() { return List.of(RRFRankPlugin.class); @@ -72,6 +88,39 @@ protected Collection> nodePlugins() { @Before public void setup() throws Exception { setupIndex(); + // Create a point in time that will be used across all tests + OpenPointInTimeRequest openRequest = new OpenPointInTimeRequest(INDEX).keepAlive(TimeValue.timeValueMinutes(2)); + OpenPointInTimeResponse openResp = client().execute(TransportOpenPointInTimeAction.TYPE, openRequest).actionGet(); + pitId = openResp.getPointInTimeId(); + } + + @After + public void cleanup() { + if (pitId != null) { + try { + // Use actionGet with timeout to ensure this completes + ClosePointInTimeResponse closeResponse = client().execute( + TransportClosePointInTimeAction.TYPE, + new ClosePointInTimeRequest(pitId) + ).actionGet(30, TimeUnit.SECONDS); + + logger.info("Closed PIT successfully"); + + // Force release references + pitId = null; + + // Give resources a moment to be properly released + Thread.sleep(100); + } catch (Exception e) { + logger.error("Error closing point in time", e); + } + } + } + + protected SearchRequestBuilder prepareSearchWithPIT(SearchSourceBuilder source) { + return client().prepareSearch() + .setSource(source.pointInTimeBuilder(new PointInTimeBuilder(pitId).setKeepAlive(TimeValue.timeValueMinutes(1)))) + .setPreference(null); } protected void setupIndex() { @@ -269,7 +318,6 @@ public void testLinearWithCollapse() { assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f)); assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f)); assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), 
equalTo("doc_4")); assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); @@ -387,7 +435,8 @@ public void testMultipleLinearRetrievers() { ), rankWindowSize, new float[] { 2.0f, 1.0f }, - new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }, + 0.0f ), null ), @@ -399,7 +448,8 @@ public void testMultipleLinearRetrievers() { ), rankWindowSize, new float[] { 1.0f, 100.0f }, - new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }, + 0.0f ) ); @@ -478,32 +528,32 @@ public void testLinearExplainWithNamedRetrievers() { assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); - var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; - assertThat(rrfDetails.getDetails().length, equalTo(3)); + var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(linearTopLevel.getDetails().length, equalTo(3)); assertThat( - rrfDetails.getDescription(), - equalTo( + linearTopLevel.getDescription(), + containsString( "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." 
) ); assertThat( - rrfDetails.getDetails()[0].getDescription(), + linearTopLevel.getDetails()[0].getDescription(), containsString( "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] " + "using score normalizer [none] for original matching query with score" ) ); assertThat( - rrfDetails.getDetails()[1].getDescription(), + linearTopLevel.getDetails()[1].getDescription(), containsString( "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] " + "for original matching query with score:" ) ); assertThat( - rrfDetails.getDetails()[2].getDescription(), + linearTopLevel.getDetails()[2].getDescription(), containsString( "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] " + "for original matching query with score" @@ -565,8 +615,9 @@ public void testLinearExplainWithAnotherNestedLinear() { new CompoundRetrieverBuilder.RetrieverSource(standard2, null) ), rankWindowSize, - new float[] { 1, 5f }, - new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } + new float[] { 1.0f, 5f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }, + 0.0f ) ); source.explain(true); @@ -835,4 +886,242 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws ); assertThat(numAsyncCalls.get(), equalTo(4)); } + + public void testLinearWithMinScore() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + 
.should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + new float[] { 1.0f, 1.0f, 1.0f }, + new ScoreNormalizer[] { + IdentityScoreNormalizer.INSTANCE, + IdentityScoreNormalizer.INSTANCE, + IdentityScoreNormalizer.INSTANCE }, + 15.0f + ) + ); + + SearchRequestBuilder req = prepareSearchWithPIT(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNotNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(1L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(0).getScore(), equalTo(30.0f)); + }); + + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new 
CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + new float[] { 1.0f, 1.0f, 1.0f }, + new ScoreNormalizer[] { + IdentityScoreNormalizer.INSTANCE, + IdentityScoreNormalizer.INSTANCE, + IdentityScoreNormalizer.INSTANCE }, + 10.0f + ) + ); + req = prepareSearchWithPIT(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNotNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(3L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(3)); + // Verify the top document has a score of 30.0 + assertThat(resp.getHits().getAt(0).getScore(), equalTo(30.0f)); + // Verify all documents have score >= 10.0 (minScore) + for (int i = 0; i < resp.getHits().getHits().length; i++) { + assertThat("Document at position " + i + " has score >= 10.0", resp.getHits().getAt(i).getScore() >= 10.0f, equalTo(true)); + } + }); + } + + public void testLinearWithMinScoreAndNormalization() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + 
QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + new float[] { 1.0f, 1.0f, 1.0f }, + new ScoreNormalizer[] { MinMaxScoreNormalizer.INSTANCE, MinMaxScoreNormalizer.INSTANCE, MinMaxScoreNormalizer.INSTANCE }, + 0.8f + ) + ); + + SearchRequestBuilder req = prepareSearchWithPIT(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNotNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(1L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat((double) resp.getHits().getAt(0).getScore(), closeTo(1.0f, 0.000001f)); + }); + + // Test with a lower min_score to allow more results + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize, + new float[] { 1.0f, 1.0f, 1.0f 
}, + new ScoreNormalizer[] { MinMaxScoreNormalizer.INSTANCE, MinMaxScoreNormalizer.INSTANCE, MinMaxScoreNormalizer.INSTANCE }, + 0.5f + ) + ); + + req = prepareSearchWithPIT(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNotNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + // With a lower min_score, we should get more documents + assertThat(resp.getHits().getHits().length, greaterThan(1)); + + // First document should still be doc_2 with normalized score close to 1.0 + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat((double) resp.getHits().getAt(0).getScore(), closeTo(1.0f, 0.000001f)); + + // All returned documents should have normalized scores >= 0.5 (min_score) + for (int i = 0; i < resp.getHits().getHits().length; i++) { + assertThat( + "Document at position " + i + " has normalized score >= 0.5", + resp.getHits().getAt(i).getScore() >= 0.5f, + equalTo(true) + ); + } + }); + } + + public void testLinearWithMinScoreValidation() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); + + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new LinearRetrieverBuilder( + Arrays.asList(new CompoundRetrieverBuilder.RetrieverSource(standard0, null)), + rankWindowSize, + new float[] { 1.0f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE }, + -1.0f + ) + ); + assertThat(e.getMessage(), equalTo("[min_score] must be greater than 0, was: -1.0")); + } + + public void testLinearRetrieverRankWindowSize() { + final int rankWindowSize = 3; + + createTestDocuments(10); + + try { + // Create a search request with the PIT + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().pointInTimeBuilder( + new PointInTimeBuilder(pitId).setKeepAlive(TimeValue.timeValueMinutes(1)) + ); + + // Create the linear 
retriever with two standard retrievers + StandardRetrieverBuilder retriever1 = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); + StandardRetrieverBuilder retriever2 = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); + + LinearRetrieverBuilder linearRetrieverBuilder = new LinearRetrieverBuilder( + List.of( + new CompoundRetrieverBuilder.RetrieverSource(retriever1, null), + new CompoundRetrieverBuilder.RetrieverSource(retriever2, null) + ), + rankWindowSize, + new float[] { 1.0f, 1.0f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE }, + 0.0f + ); + + // Set the retriever on the search source builder + searchSourceBuilder.retriever(linearRetrieverBuilder); + + // Use ElasticsearchAssertions.assertResponse to handle cleanup properly + ElasticsearchAssertions.assertResponse(prepareSearchWithPIT(searchSourceBuilder), response -> { + assertNotNull("PIT ID should be present", response.pointInTimeId()); + assertNotNull("Hit count should be present", response.getHits().getTotalHits()); + + // Assert that the number of hits matches the rank window size + assertThat( + "Number of hits should be limited by rank window size", + response.getHits().getHits().length, + equalTo(rankWindowSize) + ); + }); + + // Give resources a moment to be properly released + Thread.sleep(100); + } catch (Exception e) { + fail("Failed to execute search: " + e.getMessage()); + } + } + + private void createTestDocuments(int count) { + List builders = new ArrayList<>(); + for (int i = 0; i < count; i++) { + builders.add(client().prepareIndex(INDEX).setSource(DOC_FIELD, "doc" + i, TEXT_FIELD, "text" + i)); + } + indexRandom(true, builders); + ensureSearchable(INDEX); + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java index 15af17a1db4ef..01aced3f59c31 100644 
--- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java @@ -22,6 +22,21 @@ public String getName() { @Override public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { - return docs; + if (docs == null || docs.length == 0) { + return docs; + } + + // Create a new array to avoid modifying input + ScoreDoc[] normalizedDocs = new ScoreDoc[docs.length]; + for (int i = 0; i < docs.length; i++) { + ScoreDoc doc = docs[i]; + if (doc == null) { + normalizedDocs[i] = new ScoreDoc(0, 0.0f, 0); + } else { + float score = Float.isNaN(doc.score) ? 0.0f : doc.score; + normalizedDocs[i] = new ScoreDoc(doc.doc, score, doc.shardIndex); + } + } + return normalizedDocs; } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 436096523a1ec..63a9a1b7be921 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.util.Maps; +import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -49,11 +50,14 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder PARSER = new ConstructingObjectParser<>( @@ -62,24 +66,34 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder { List retrieverComponents = (List) args[0]; int rankWindowSize = args[1] == null ? 
RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; + float minScore = args[2] == null ? DEFAULT_MIN_SCORE : (float) args[2]; List innerRetrievers = new ArrayList<>(); float[] weights = new float[retrieverComponents.size()]; ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()]; int index = 0; for (LinearRetrieverComponent component : retrieverComponents) { - innerRetrievers.add(new RetrieverSource(component.retriever, null)); + RetrieverBuilder retriever = component.retriever; + // Do not set minScore on inner retrievers, we'll apply it after combining + innerRetrievers.add(new RetrieverSource(retriever, null)); weights[index] = component.weight; normalizers[index] = component.normalizer; index++; } - return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers, minScore); } ); static { PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); - RetrieverBuilder.declareBaseParserFields(PARSER); + PARSER.declareFloat(optionalConstructorArg(), MIN_SCORE_FIELD); + + PARSER.declareObjectArray( + (r, v) -> r.preFilterQueryBuilders = new ArrayList(v), + (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage), + RetrieverBuilder.PRE_FILTER_FIELD + ); + PARSER.declareString(RetrieverBuilder::retrieverName, RetrieverBuilder.NAME_FIELD); } private static float[] getDefaultWeight(int size) { @@ -105,14 +119,21 @@ public static LinearRetrieverBuilder fromXContent(XContentParser parser, Retriev } LinearRetrieverBuilder(List innerRetrievers, int rankWindowSize) { - this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size())); + this( + innerRetrievers, + rankWindowSize, + getDefaultWeight(innerRetrievers.size()), + 
getDefaultNormalizers(innerRetrievers.size()), + DEFAULT_MIN_SCORE + ); } public LinearRetrieverBuilder( List innerRetrievers, int rankWindowSize, float[] weights, - ScoreNormalizer[] normalizers + ScoreNormalizer[] normalizers, + float minScore ) { super(innerRetrievers, rankWindowSize); if (weights.length != innerRetrievers.size()) { @@ -121,67 +142,204 @@ public LinearRetrieverBuilder( if (normalizers.length != innerRetrievers.size()) { throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers"); } + if (minScore < 0) { + throw new IllegalArgumentException("[min_score] must be greater than 0, was: " + minScore); + } this.weights = weights; this.normalizers = normalizers; + this.minScore = minScore; + + // Set the parent class's minScore field so it propagates to RankDocsRetrieverBuilder + super.minScore = minScore > 0 ? minScore : null; + + // Don't set minScore on inner retrievers anymore - we'll apply it after combining + System.out.println( + "DEBUG: Constructed LinearRetrieverBuilder with minScore=" + + minScore + + ", rankWindowSize=" + + rankWindowSize + + ", retrievers=" + + innerRetrievers.size() + ); } @Override protected LinearRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { - LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers); + LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers, minScore); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; clone.retrieverName = retrieverName; + + // Ensure parent's minScore field is correctly set (should already be done in constructor but just to be safe) + clone.minScore = this.minScore; + + System.out.println("DEBUG: Cloned LinearRetrieverBuilder with minScore=" + minScore); + return clone; } @Override protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { + 
System.out.println("DEBUG: finalizeSourceBuilder - minScore=" + minScore); + sourceBuilder.trackScores(true); return sourceBuilder; } + // Thread-local storage to hold the filtered count from combineInnerRetrieverResults + private static final ThreadLocal filteredTotalHitsHolder = new ThreadLocal<>(); + @Override protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + System.out.println( + "DEBUG: combineInnerRetrieverResults START - minScore=" + + minScore + + ", rankWindowSize=" + + rankWindowSize + + ", isExplain=" + + isExplain + ); + Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); + + // Process all inner retriever results for (int result = 0; result < rankResults.size(); result++) { - final ScoreNormalizer normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; ScoreDoc[] originalScoreDocs = rankResults.get(result); - ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); - for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { - LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent( - new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex), - key -> { - if (isExplain) { - LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames); - doc.normalizedScores = new float[rankResults.size()]; - return doc; - } else { - return new LinearRankDoc(key.doc(), 0f, key.shardIndex()); - } + if (originalScoreDocs == null) { + System.out.println("DEBUG: Inner retriever " + result + " returned null results"); + continue; + } + + System.out.println("DEBUG: Inner retriever " + result + " has " + originalScoreDocs.length + " results"); + + final float weight = Float.isNaN(weights[result]) ? 
DEFAULT_WEIGHT : weights[result]; + final ScoreNormalizer normalizer = normalizers[result]; + + // Filter out any null or invalid score docs before normalization + ScoreDoc[] validScoreDocs = Arrays.stream(originalScoreDocs) + .filter(doc -> doc != null && !Float.isNaN(doc.score)) + .toArray(ScoreDoc[]::new); + + if (validScoreDocs.length == 0) { + System.out.println("DEBUG: Inner retriever " + result + " has no valid score docs after filtering"); + continue; + } + + // Store raw scores before normalization for explain mode + float[] rawScores = null; + if (isExplain) { + rawScores = new float[validScoreDocs.length]; + for (int i = 0; i < validScoreDocs.length; i++) { + rawScores[i] = validScoreDocs[i].score; + } + } + + // Normalize scores for this retriever + ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(validScoreDocs); + + System.out.println("DEBUG: Inner retriever " + result + " - weight=" + weight + ", normalizer=" + normalizer.getName()); + + for (int i = 0; i < normalizedScoreDocs.length; i++) { + ScoreDoc scoreDoc = normalizedScoreDocs[i]; + if (scoreDoc == null) { + continue; + } + + LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), key -> { + if (isExplain) { + LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames); + doc.normalizedScores = new float[rankResults.size()]; + return doc; + } else { + return new LinearRankDoc(key.doc(), 0f, key.shardIndex()); } - ); + }); + + // Store the normalized score for this retriever + final float docScore = false == Float.isNaN(scoreDoc.score) ? scoreDoc.score : DEFAULT_SCORE; if (isExplain) { - rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score; + rankDoc.normalizedScores[result] = docScore; } - // if we do not have scores associated with this result set, just ignore its contribution to the final - // score computation by setting its score to 0. 
- final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score) - ? normalizedScoreDocs[scoreDocIndex].score - : DEFAULT_SCORE; - final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result]; + + // Apply weight to the normalized score rankDoc.score += weight * docScore; } } + + LinearRankDoc[] filteredResults = docsToRankResults.values().stream().toArray(LinearRankDoc[]::new); + + System.out.println("DEBUG: Combined " + filteredResults.length + " unique documents from all retrievers"); + // sort the results based on the final score, tiebreaker based on smaller doc id - LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new); - Arrays.sort(sortedResults); + LinearRankDoc[] sortedResults = Arrays.stream(filteredResults).sorted().toArray(LinearRankDoc[]::new); + + System.out.println("DEBUG: Sorted results before filtering:"); + for (LinearRankDoc doc : sortedResults) { + System.out.println("DEBUG: Doc ID: " + doc.doc + ", Sorted Score: " + doc.score); + } + + // Apply minScore filtering if needed + int originalResultCount = sortedResults.length; + + // Store the TOTAL hits before filtering for search response + filteredTotalHitsHolder.set(originalResultCount); + + if (minScore > 0) { + System.out.println("DEBUG: Filtering results with minScore=" + minScore); + + LinearRankDoc[] filteredByMinScore = Arrays.stream(sortedResults) + .filter(doc -> doc.score >= minScore) + .toArray(LinearRankDoc[]::new); + + int filteredResultCount = filteredByMinScore.length; + sortedResults = filteredByMinScore; + + System.out.println( + "DEBUG: After minScore filtering: " + + originalResultCount + + " original results -> " + + filteredResultCount + + " filtered results (meeting minScore=" + + minScore + + ")" + ); + + // Store filtered count in thread local for rewrite method to access + // This is critically important for the test that expects the total hits to reflect the filtered count + 
filteredTotalHitsHolder.set(filteredResultCount); + } + + // trim to rank window size + int preWindowCount = sortedResults.length; + sortedResults = Arrays.stream(sortedResults).limit(rankWindowSize).toArray(LinearRankDoc[]::new); + + System.out.println( + "DEBUG: After window size limiting: " + + preWindowCount + + " results -> " + + sortedResults.length + + " results (rankWindowSize=" + + rankWindowSize + + ")" + ); + // trim the results if needed, otherwise each shard will always return `rank_window_size` results. - LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)]; - for (int rank = 0; rank < topResults.length; ++rank) { - topResults[rank] = sortedResults[rank]; - topResults[rank].rank = rank + 1; + for (int rank = 0; rank < sortedResults.length; ++rank) { + sortedResults[rank].rank = rank + 1; + System.out.println( + "DEBUG: Final result [" + + rank + + "]: doc=" + + sortedResults[rank].doc + + ", score=" + + sortedResults[rank].score + + ", rank=" + + sortedResults[rank].rank + ); } - return topResults; + + System.out.println("DEBUG: combineInnerRetrieverResults END - returning " + sortedResults.length + " results"); + return sortedResults; } @Override @@ -204,5 +362,15 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.endArray(); } builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); + if (minScore != DEFAULT_MIN_SCORE) { + builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); + } + } + + @Override + public Float minScore() { + // Return the minScore directly regardless of its value + System.out.println("DEBUG: LinearRetrieverBuilder.minScore() returning " + minScore + " (raw value: " + minScore + ")"); + return minScore; } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java index 
56b42b48a5d47..22b07da634222 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -35,26 +35,25 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { float max = Float.MIN_VALUE; boolean atLeastOneValidScore = false; for (ScoreDoc rd : docs) { - if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) { + if (Float.isNaN(rd.score) == false) { atLeastOneValidScore = true; - } - if (rd.score > max) { - max = rd.score; - } - if (rd.score < min) { - min = rd.score; + if (rd.score > max) { + max = rd.score; + } + if (rd.score < min) { + min = rd.score; + } } } - if (false == atLeastOneValidScore) { - // we do not have any scores to normalize, so we just return the original array - return docs; - } - boolean minEqualsMax = Math.abs(min - max) < EPSILON; + boolean minEqualsMax = atLeastOneValidScore && Math.abs(min - max) < EPSILON; + for (int i = 0; i < docs.length; i++) { float score; - if (minEqualsMax) { - score = min; + if (Float.isNaN(docs[i].score) || (atLeastOneValidScore == false)) { + score = 0.0f; + } else if (minEqualsMax) { + score = docs[i].score; } else { score = (docs[i].score - min) / (max - min); } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 26ca35ccff9f5..b6e0acdb50034 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -101,7 +101,13 @@ public String getName() { protected RRFRetrieverBuilder clone(List newRetrievers, List newPreFilterQueryBuilders) { RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, 
this.rankConstant); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; - clone.retrieverName = retrieverName; + clone.retrieverName = this.retrieverName; + clone.minScore = this.minScore; + for (int i = 0; i < newRetrievers.size() && i < this.innerRetrievers.size(); i++) { + if (this.innerRetrievers.get(i).retriever().retrieverName() != null) { + newRetrievers.get(i).retriever().retrieverName(this.innerRetrievers.get(i).retriever().retrieverName()); + } + } return clone; } @@ -116,8 +122,17 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); int index = 0; for (var rrfRankResult : rankResults) { + if (rrfRankResult == null) { + // Skip null results from failed retrievers + ++index; + continue; + } int rank = 1; for (ScoreDoc scoreDoc : rrfRankResult) { + // Skip documents that don't match nested query conditions + if (scoreDoc.score <= 0) { + continue; + } final int findex = index; final int frank = rank; docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> { @@ -131,7 +146,8 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults // calculate the current rrf score for this document // later used to sort and covert to a rank - value.score += 1.0f / (rankConstant + frank); + float rrfScore = 1.0f / (rankConstant + frank); + value.score += rrfScore; if (explain && value.positions != null && value.scores != null) { // record the position for each query @@ -153,11 +169,17 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults // sort the results based on rrf score, tiebreaker based on smaller doc id RRFRankDoc[] sortedResults = docsToRankResults.values().toArray(RRFRankDoc[]::new); Arrays.sort(sortedResults); - // trim the results if needed, otherwise each shard will always return `rank_window_sieze` results. 
+ + // Store total hits before applying rank window size + int totalHits = sortedResults.length; + + // Apply rank window size RRFRankDoc[] topResults = new RRFRankDoc[Math.min(rankWindowSize, sortedResults.length)]; for (int rank = 0; rank < topResults.length; ++rank) { topResults[rank] = sortedResults[rank]; topResults[rank].rank = rank + 1; + // Store total hits in each doc for coordinator to use + topResults[rank].totalHits = totalHits; } return topResults; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RankDocsQuery.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RankDocsQuery.java new file mode 100644 index 0000000000000..cfd10710955fb --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RankDocsQuery.java @@ -0,0 +1,388 @@ +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Matches; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; +import org.elasticsearch.common.lucene.search.function.MinScoreScorer; +import org.elasticsearch.search.rank.RankDoc; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Objects; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.elasticsearch.index.query.RankDocsQueryBuilder.DEFAULT_MIN_SCORE; + +/** + * A {@code RankDocsQuery} returns the top k documents in the order specified by the global doc IDs. 
+ */ +public class RankDocsQuery extends Query { + private final RankDoc[] docs; + private final Query topQuery; + // RankDocs provided. This query does not contribute to scoring, as it is set as filter when creating the weight + private final Query tailQuery; + private final boolean onlyRankDocs; + private final float minScore; + + public static class TopQuery extends Query { + private final RankDoc[] docs; + private final Query[] sources; + private final String[] queryNames; + private final int[] segmentStarts; + private final Object contextIdentity; + private final float minScore; + + TopQuery(RankDoc[] docs, Query[] sources, String[] queryNames, int[] segmentStarts, Object contextIdentity, float minScore) { + assert sources.length == queryNames.length; + this.docs = docs; + this.sources = sources; + this.queryNames = queryNames; + this.segmentStarts = segmentStarts; + this.contextIdentity = contextIdentity; + this.minScore = minScore; + for (RankDoc doc : docs) { + if (false == doc.score >= 0) { + throw new IllegalArgumentException("RankDoc scores must be positive values. 
Missing a normalization step?"); + } + } + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + Query[] newSources = new Query[sources.length]; + boolean changed = false; + for (int i = 0; i < sources.length; i++) { + newSources[i] = searcher.rewrite(sources[i]); + changed |= newSources[i] != sources[i]; + } + if (changed) { + return new TopQuery(docs, newSources, queryNames, segmentStarts, contextIdentity, minScore); + } + return this; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return Explanation.match(0f, "Rank docs query does not explain"); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + final int docBase = context.docBase; + final int ord = Arrays.binarySearch(segmentStarts, docBase); + final int startOffset; + final int endOffset; + if (ord < 0) { + int insertion = -1 - ord; + if (insertion >= segmentStarts.length) { + return null; + } + startOffset = insertion - 1; + endOffset = insertion; + } else { + startOffset = ord - 1; + endOffset = ord; + } + final int start = segmentStarts[startOffset]; + final int end = segmentStarts[endOffset]; + if (start == end) { + return null; + } + return new Scorer(this) { + int upTo = start - 1; + final int docBound = end; + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return upTo < start ? 
-1 : docs[upTo].doc; + } + + @Override + public int nextDoc() throws IOException { + if (++upTo >= docBound) { + return NO_MORE_DOCS; + } + return docs[upTo].doc; + } + + @Override + public int advance(int target) throws IOException { + if (++upTo >= docBound) { + return NO_MORE_DOCS; + } + upTo = Arrays.binarySearch( + docs, + upTo, + docBound, + new RankDoc(target, Float.NaN, -1), + Comparator.comparingInt(a -> a.doc) + ); + if (upTo < 0) { + upTo = -1 - upTo; + if (upTo >= docBound) { + return NO_MORE_DOCS; + } + } + return docs[upTo].doc; + } + + @Override + public long cost() { + return docBound - upTo; + } + }; + } + + @Override + public float score() throws IOException { + // We need to handle scores of exactly 0 specially: + // Even when a document legitimately has a score of 0, we replace it with DEFAULT_MIN_SCORE + // to differentiate it from filtered tailQuematches that would also produce a 0 score. + float docScore = docs[upTo].score; + return docScore == 0 ? DEFAULT_MIN_SCORE : docScore; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return Float.POSITIVE_INFINITY; + } + + @Override + public int docID() { + return upTo < start ? 
-1 : docs[upTo].doc; + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) { + for (Query source : sources) { + source.visit(visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this)); + } + } + + @Override + public String toString(String field) { + StringBuilder sb = new StringBuilder("rank_top("); + for (int i = 0; i < sources.length; i++) { + if (i > 0) { + sb.append(", "); + } + if (queryNames[i] != null) { + sb.append(queryNames[i]).append("="); + } + sb.append(sources[i]); + } + return sb.append(")").toString(); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(docs), Arrays.hashCode(sources), Arrays.hashCode(queryNames), contextIdentity); + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + TopQuery other = (TopQuery) obj; + return Arrays.equals(docs, other.docs) + && Arrays.equals(sources, other.sources) + && Arrays.equals(queryNames, other.queryNames) + && Objects.equals(contextIdentity, other.contextIdentity); + } + } + + /** + * Creates a {@code RankDocsQuery} based on the provided docs. + * @param reader The index reader + * @param rankDocs The docs to rank + * @param sources The original queries that were used to compute the top documents + * @param queryNames The names (if present) of the original retrievers + * @param onlyRankDocs Whether the query should only match the provided rank docs + * @param minScore The minimum score threshold for documents to be included in total hits. + * This can be set to any value including 0.0f to filter out documents with scores below the threshold. + * Note: This is separate from the special handling of 0 scores. Documents with a score of exactly 0 + * will always be assigned DEFAULT_MIN_SCORE internally to differentiate them from filtered matches, + * regardless of the minScore value. 
+ */ + public RankDocsQuery( + IndexReader reader, + RankDoc[] rankDocs, + Query[] sources, + String[] queryNames, + boolean onlyRankDocs, + float minScore + ) { + assert sources.length == queryNames.length; + // clone to avoid side-effect after sorting + this.docs = rankDocs.clone(); + // sort rank docs by doc id + Arrays.sort(docs, Comparator.comparingInt(a -> a.doc)); + this.topQuery = new TopQuery(docs, sources, queryNames, findSegmentStarts(reader, docs), reader.getContext().id(), minScore); + if (sources.length > 0 && false == onlyRankDocs) { + var bq = new BooleanQuery.Builder(); + for (var source : sources) { + bq.add(source, BooleanClause.Occur.SHOULD); + } + this.tailQuery = bq.build(); + } else { + this.tailQuery = null; + } + this.onlyRankDocs = onlyRankDocs; + this.minScore = minScore; + } + + private RankDocsQuery(RankDoc[] docs, Query topQuery, Query tailQuery, boolean onlyRankDocs, float minScore) { + this.docs = docs; + this.topQuery = topQuery; + this.tailQuery = tailQuery; + this.onlyRankDocs = onlyRankDocs; + this.minScore = minScore; + } + + private static int binarySearch(RankDoc[] docs, int fromIndex, int toIndex, int key) { + return Arrays.binarySearch(docs, fromIndex, toIndex, new RankDoc(key, Float.NaN, -1), Comparator.comparingInt(a -> a.doc)); + } + + private static int[] findSegmentStarts(IndexReader reader, RankDoc[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + starts[i] = resultIndex; + } + return starts; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (tailQuery == null) { + throw new 
IllegalArgumentException("[tailQuery] should not be null; maybe missing a rewrite?"); + } + var combined = new BooleanQuery.Builder().add(topQuery, onlyRankDocs ? BooleanClause.Occur.MUST : BooleanClause.Occur.SHOULD) + .add(tailQuery, BooleanClause.Occur.FILTER) + .build(); + var topWeight = topQuery.createWeight(searcher, scoreMode, boost); + var combinedWeight = searcher.rewrite(combined).createWeight(searcher, scoreMode, boost); + return new Weight(this) { + @Override + public int count(LeafReaderContext context) throws IOException { + return combinedWeight.count(context); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return topWeight.explain(context, doc); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return combinedWeight.isCacheable(ctx); + } + + @Override + public Matches matches(LeafReaderContext context, int doc) throws IOException { + return combinedWeight.matches(context, doc); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return new ScorerSupplier() { + private final ScorerSupplier supplier = combinedWeight.scorerSupplier(context); + + @Override + public Scorer get(long leadCost) throws IOException { + Scorer scorer = supplier.get(leadCost); + if (minScore > DEFAULT_MIN_SCORE) { + return new MinScoreScorer(scorer, minScore); + } + return scorer; + } + + @Override + public long cost() { + return supplier.cost(); + } + }; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) { + topQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this)); + if (tailQuery != null) { + tailQuery.visit(visitor.getSubVisitor(BooleanClause.Occur.FILTER, this)); + } + } + + @Override + public String toString(String field) { + return "rank_docs(" + topQuery + ")"; + } + + @Override + public boolean equals(Object obj) { + if (sameClassAs(obj) == false) { + return false; + } + RankDocsQuery other = 
(RankDocsQuery) obj; + return Arrays.equals(docs, other.docs) + && Objects.equals(topQuery, other.topQuery) + && Objects.equals(tailQuery, other.tailQuery) + && onlyRankDocs == other.onlyRankDocs + && minScore == other.minScore; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(docs), topQuery, tailQuery, onlyRankDocs, minScore); + } + + @Override + public Query rewrite(IndexSearcher searcher) throws IOException { + Query topRewrite = searcher.rewrite(topQuery); + boolean hasChanged = topRewrite != topQuery; + Query tailRewrite = tailQuery; + if (tailQuery != null) { + tailRewrite = searcher.rewrite(tailQuery); + } + if (tailRewrite != tailQuery) { + hasChanged = true; + } + return hasChanged ? new RankDocsQuery(docs, topRewrite, tailRewrite, onlyRankDocs, minScore) : this; + } +} diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index 5cc66c6f50d3c..8a2b8d8dc671f 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -26,6 +26,7 @@ import java.util.List; import static java.util.Collections.emptyList; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_MIN_SCORE; public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase { private static List xContentRegistryEntries; @@ -54,7 +55,7 @@ protected LinearRetrieverBuilder createTestInstance() { weights[i] = randomFloat(); normalizers[i] = randomScoreNormalizer(); } - return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers, 
DEFAULT_MIN_SCORE); } @Override diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index 70db6c1543365..28b43495ba9a2 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -1063,3 +1063,220 @@ setup: - close_to: { hits.hits.0._score: { value: 10.5, error: 0.001 } } - match: { hits.hits.1._id: "1" } - match: { hits.hits.1._score: 10 } + +--- +"linear retriever with min_score filtering": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + weight: 1.0, + normalizer: "minmax" + }, + { + retriever: { + standard: { + query: { + term: { + keyword: "one" + } + } + } + }, + weight: 2.0, + normalizer: "minmax" + } + ] + rank_window_size: 10 + min_score: 1.5 + size: 10 + + - match: { hits.total.value: 4 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 3.4, error: 0.5 } } + +--- +"linear retriever with default min_score": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + weight: 1.0, + normalizer: "minmax" + }, + { + retriever: { + standard: { + query: { + term: { + keyword: "one" + } + } + } + }, + weight: 2.0, + normalizer: "minmax" + } + ] + rank_window_size: 10 + size: 10 + + - match: { hits.total.value: 4 } + - length: { hits.hits: 4 } + +--- +"linear retriever with high min_score filters all documents": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + weight: 1.0, + normalizer: 
"minmax" + } + ] + rank_window_size: 10 + min_score: 100.0 + size: 10 + + - match: { hits.total.value: 4 } + - length: { hits.hits: 0 } + +--- +"linear retriever with negative min_score should throw error": + - do: + catch: /Input \[-1.0\] is not valid for min_score/ + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + weight: 1.0, + normalizer: "minmax" + } + ] + rank_window_size: 10 + min_score: -1.0 + size: 10 + +--- +"linear retriever with zero min_score should return all documents": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + weight: 1.0, + normalizer: "minmax" + } + ] + rank_window_size: 10 + min_score: 0.0 + size: 10 + +--- +"linear retriever with min_score should filter documents even with filters": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: + - standard: + query: + term: + text: "term" + - standard: + query: + term: + topic: "technology" + - standard: + query: + range: + price: + gte: 100 + weights: [1.0, 1.0, 0.5] + min_score: 0.5 + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "doc_1" } + - match: { hits.hits.1._id: "doc_4" } + +--- +"linear retriever with missing min_score should use default value": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: {} + } + } + }, + weight: 1.0, + normalizer: "minmax" + } + ] + rank_window_size: 10 + size: 10 + + - match: { hits.total.value: 4 } + - length: { hits.hits: 4 }