From 825683f58441576e8f417e80cef5ca3ed59cbde3 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 8 Jan 2025 00:33:07 +0200 Subject: [PATCH 01/57] iter --- .../elasticsearch/search/SearchModule.java | 2 + .../search/rank/LinearRankDoc.java | 88 +++++++ .../retriever/LinearRetrieverBuilder.java | 231 ++++++++++++++++++ .../search/retriever/RetrieversFeatures.java | 3 +- .../retriever/WrapperRetrieverBuilder.java | 75 ++++++ 5 files changed, 398 insertions(+), 1 deletion(-) create mode 100644 server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index d282ba425b126..591e44a2bf417 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -231,6 +231,7 @@ import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.LinearRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; @@ -1080,6 +1081,7 @@ private void registerFetchSubPhase(FetchSubPhase subPhase) { private void registerRetrieverParsers(List plugins) { registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); + registerRetriever(new RetrieverSpec<>(LinearRetrieverBuilder.NAME, LinearRetrieverBuilder::fromXContent)); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java new file mode 100644 index 0000000000000..170ddcc44a6e8 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.rank; + +import org.apache.lucene.search.Explanation; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.Arrays; + +public class LinearRankDoc extends RankDoc { + + public float[] weights; + public float[] scores; + public String[] normalizers; + + public LinearRankDoc(int doc, float score, int shardIndex, int queriesCount) { + super(doc, score, shardIndex); + this.weights = new float[queriesCount]; + this.scores = new float[queriesCount]; + Arrays.fill(scores, 0f); + this.normalizers = new String[queriesCount]; + } + + public LinearRankDoc(StreamInput in) throws IOException { + super(in.readVInt(), in.readFloat(), in.readVInt()); + weights = in.readFloatArray(); + scores = in.readFloatArray(); + normalizers = in.readStringArray(); + } + + @Override + public Explanation explain(Explanation[] sources, String[] queryNames) { + Explanation[] details = new Explanation[sources.length]; + for (int i = 0; i < sources.length; i++) { + final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]"; + final String queryIdentifier = "at index [" + i + "]" + queryAlias; + if (scores[i] > 0) { + details[i] = Explanation.match( + weights[i] * scores[i], + "weighted score: [" + + weights[i] * scores[i] + + "] in query " + + queryIdentifier + + " computed as [" + + weights[i] + + " * " + + scores[i] + + "]" + + " using score normalizer [" + + normalizers[i] + + "]" + + " for original matching query with score:", + sources[i] + ); + } else { + final String description = "weighted score: [0], result not found in query " + queryIdentifier; + details[i] = Explanation.noMatch(description); + } + } + return Explanation.match( + score, + "weighted linear combination score: [" + + score + + "] computed for normalized scores " + + Arrays.toString(scores) + + " and weights " + + Arrays.toString(weights) + + "] as sum of (weight[i] * score[i]) for each query.", + details + ); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeFloatArray(weights); + out.writeFloatArray(scores); + out.writeStringArray(normalizers); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java new file mode 100644 index 0000000000000..95cfc9edb0b59 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -0,0 +1,231 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.rank.LinearRankDoc; +import org.elasticsearch.search.rank.RankBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * The {@code LinearRetrieverBuilder} supports the combination of different retrievers through a weighted linear combination. + * For example, assume that we have retrievers r1 and r2, the final score of the {@code LinearRetrieverBuilder} is defined as + * {@code score(r)=w1*score(r1) + w2*score(r2)}. + * Each sub-retriever score can be normalized before being considered for the weighted linear sum, by setting the appropriate + * normalizer parameter. + * + */ +public class LinearRetrieverBuilder extends CompoundRetrieverBuilder { + + public static final String NAME = "linear_retriever"; + + public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_support"); + + public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); + public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); + + private final List wrappedRetrievers; + + static final float DEFAULT_WEIGHT = 1f; + static final ScoreNormalizer DEFAULT_NORMALIZER = ScoreNormalizer.IDENTITY; + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + false, + args -> { + List childRetrievers = (List) args[0]; + List innerRetrievers = childRetrievers.stream().map(r -> new RetrieverSource(r.retriever, null)).toList(); + int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; + return new LinearRetrieverBuilder(childRetrievers, innerRetrievers, rankWindowSize); + } + ); + + // public record WrappedRetriever(RetrieverBuilder retrieverBuilder, float weight, ScoreNormalizer normalizer) {} + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> { + // float weight = -1f; + // ScoreNormalizer normalizer = null; + // RetrieverBuilder retrieverBuilder = null; + // while (p.nextToken() != null && p.currentName() != null) { + // String name = p.currentName(); + // switch (name) { + // case "weight": + // p.nextToken(); + // weight = p.floatValue(); + // break; + // case "retriever": + // p.nextToken(); + // p.nextToken(); + // retrieverBuilder = p.namedObject(RetrieverBuilder.class, p.currentName(), c); + // c.trackRetrieverUsage(retrieverBuilder.getName()); + // p.nextToken(); + // break; + // case "normalizer": + // p.nextToken(); + // String normalizerName = p.text(); + // normalizer = ScoreNormalizer.find(normalizerName); + // break; + // default: + // throw new ParsingException(p.getTokenLocation(), "Unknown key {" + name + "} provided"); + // } + // } + // ; + // return new WrappedRetriever(retrieverBuilder, weight, normalizer); + p.nextToken(); + WrapperRetrieverBuilder retrieverBuilder = WrapperRetrieverBuilder.fromXContent(p, c); + p.nextToken(); + return retrieverBuilder; + }, RETRIEVERS_FIELD); + PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + RetrieverBuilder.declareBaseParserFields(NAME, PARSER); + } + + public static LinearRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + if (context.clusterSupportsFeature(LINEAR_RETRIEVER_SUPPORTED) == false) { + throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]"); + } + return PARSER.apply(parser, context); + } + + protected LinearRetrieverBuilder( + List wrappedRetrievers, + List innerRetrievers, + int rankWindowSize + ) { + super(innerRetrievers, rankWindowSize); + this.wrappedRetrievers = wrappedRetrievers; + } + + @Override + protected LinearRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { + LinearRetrieverBuilder clone = new LinearRetrieverBuilder(wrappedRetrievers, newChildRetrievers, rankWindowSize); + clone.preFilterQueryBuilders = newPreFilterQueryBuilders; + return clone; + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); + for (int resIndex = 0; resIndex < rankResults.size(); resIndex++) { + ScoreDoc[] originalScoreDocs = rankResults.get(resIndex); + ScoreDoc[] normalizedScoreDocs = wrappedRetrievers.get(resIndex).normalizer.normalizeScores(originalScoreDocs); + for (int i = 0; i < normalizedScoreDocs.length; i++) { + int finalResIndex = resIndex; + int finalI = i; + docsToRankResults.compute(new RankDoc.RankKey(originalScoreDocs[i].doc, originalScoreDocs[i].shardIndex), (key, value) -> { + if (value == null) { + value = new LinearRankDoc( + originalScoreDocs[finalI].doc, + 0, + originalScoreDocs[finalI].shardIndex, + rankResults.size() + ); + } + value.scores[finalResIndex] = normalizedScoreDocs[finalI].score; + value.weights[finalResIndex] = wrappedRetrievers.get(finalResIndex).weight; + value.normalizers[finalResIndex] = wrappedRetrievers.get(finalResIndex).normalizer.name(); + value.score += wrappedRetrievers.get(finalResIndex).weight * normalizedScoreDocs[finalI].score; + return value; + }); + } + } + // sort the results based on rrf score, tiebreaker based on smaller doc id + LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new); + Arrays.sort(sortedResults); + // trim the results if needed, otherwise each shard will always return `rank_window_size` results. + LinearRankDoc[] topResults = new LinearRankDoc[Math.min(rankWindowSize, sortedResults.length)]; + for (int rank = 0; rank < topResults.length; ++rank) { + topResults[rank] = sortedResults[rank]; + topResults[rank].rank = rank + 1; + } + return topResults; + } + + @Override + public String getName() { + return NAME; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + + } + + enum ScoreNormalizer { + IDENTITY("identity") { + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + // no-op + return docs; + } + }, + MINMAX("minmax") { + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + // create a new array to avoid changing ScoreDocs in place + ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; + // to avoid 0 scores + float epsilon = Float.MIN_NORMAL; + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + for (ScoreDoc rd : docs) { + if (rd.score > max) { + max = rd.score; + } + if (rd.score < min) { + min = rd.score; + } + } + for (int i = 0; i < docs.length; i++) { + float score = epsilon + ((docs[i].score - min) / (max - min)); + scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); + } + return scoreDocs; + } + }; + + private final String name; + + ScoreNormalizer(String name) { + this.name = name; + } + + abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); + + static ScoreNormalizer find(String name) { + for (ScoreNormalizer normalizer : values()) { + if (normalizer.name.equalsIgnoreCase(name)) { + return normalizer; + } + } + throw new IllegalArgumentException( + "Unknown normalizer [" + name + "] provided. Supported values are: " + Arrays.stream(values()).map(Enum::name).toList() + ); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java index 74a8b30c8e7dc..5a98618e70bc3 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java @@ -25,7 +25,8 @@ public Set getFeatures() { return Set.of( RetrieverBuilder.RETRIEVERS_SUPPORTED, StandardRetrieverBuilder.STANDARD_RETRIEVER_SUPPORTED, - KnnRetrieverBuilder.KNN_RETRIEVER_SUPPORTED + KnnRetrieverBuilder.KNN_RETRIEVER_SUPPORTED, + LinearRetrieverBuilder.LINEAR_RETRIEVER_SUPPORTED ); } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java new file mode 100644 index 0000000000000..dbd5d5e16491f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.DEFAULT_NORMALIZER; +import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.DEFAULT_WEIGHT; +import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.RETRIEVERS_FIELD; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class WrapperRetrieverBuilder implements ToXContentObject { + + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + public static final ParseField WEIGHT_FIELD = new ParseField("weight"); + public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer"); + + public static final String NAME = "component"; + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + false, + args -> { + RetrieverBuilder base = (RetrieverBuilder) args[0]; + float weight = args[1] == null ? DEFAULT_WEIGHT : (float) args[1]; + LinearRetrieverBuilder.ScoreNormalizer normalizer = args[2] == null + ? DEFAULT_NORMALIZER + : LinearRetrieverBuilder.ScoreNormalizer.find((String) args[2]); + return new WrapperRetrieverBuilder(base, weight, normalizer); + } + ); + + static { + PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { + RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, n, c); + c.trackRetrieverUsage(retrieverBuilder.getName()); + return retrieverBuilder; + }, RETRIEVER_FIELD); + PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); + PARSER.declareString(optionalConstructorArg(), NORMALIZER_FIELD); + } + + RetrieverBuilder retriever; + float weight; + LinearRetrieverBuilder.ScoreNormalizer normalizer; + + public WrapperRetrieverBuilder(RetrieverBuilder base, float weight, LinearRetrieverBuilder.ScoreNormalizer normalizer) { + this.retriever = base; + this.weight = weight; + this.normalizer = normalizer; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RETRIEVERS_FIELD.getPreferredName(), retriever); + return builder; + } + + public static WrapperRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + return PARSER.apply(parser, context); + } +} From 466c02666ff076ac26fa282883e68e08549e79d9 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 13 Jan 2025 11:11:22 +0200 Subject: [PATCH 02/57] iter --- docs/reference/search/retriever.asciidoc | 21 +- .../search.retrievers/40_linear_retriever.yml | 239 ++++++++++++++++++ .../search/rank/LinearRankDoc.java | 42 ++- .../retriever/LinearRetrieverBuilder.java | 187 +++++--------- .../retriever/LinearRetrieverComponent.java | 121 +++++++++ .../retriever/WrapperRetrieverBuilder.java | 75 ------ .../LinearRetrieverBuilderParsingTests.java | 83 ++++++ 7 files changed, 559 insertions(+), 209 deletions(-) create mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java delete mode 100644 server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java create mode 100644 server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index 7e98297b780e6..28baee4d4c7ad 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -28,6 +28,9 @@ A <> that replaces the functionality of a traditi `knn`:: A <> that replaces the functionality of a <>. +`linear`:: +A <> that linearly combines the scores of other retrievers for the top documents. + `rescorer`:: A <> that replaces the functionality of the <>. @@ -263,6 +266,18 @@ GET /restaurants/_search This value must be fewer than or equal to `num_candidates`. <5> The size of the initial candidate set from which the final `k` nearest neighbors are selected. +[[linear-retriever]] +==== Linear Retriever +A retriever that normalizes and linearly combines the scores of other retrievers. + +===== Parameters + +`retrievers`:: +(Required, A list of <>) + +sth + ++ [[rrf-retriever]] ==== RRF Retriever @@ -576,7 +591,7 @@ This example demonstrates how to deploy the {ml-docs}/ml-nlp-rerank.html[Elastic Follow these steps: -. Create an inference endpoint for the `rerank` task using the <>. +. Create an inference endpoint for the `rerank` task using the <>. + [source,console] ---- @@ -584,7 +599,7 @@ PUT _inference/rerank/my-elastic-rerank { "service": "elasticsearch", "service_settings": { - "model_id": ".rerank-v1", + "model_id": ".rerank-v1", "num_threads": 1, "adaptive_allocations": { <1> "enabled": true, @@ -595,7 +610,7 @@ PUT _inference/rerank/my-elastic-rerank } ---- // TEST[skip:uses ML] -<1> {ml-docs}/ml-nlp-auto-scale.html#nlp-model-adaptive-allocations[Adaptive allocations] will be enabled with the minimum of 1 and the maximum of 10 allocations. +<1> {ml-docs}/ml-nlp-auto-scale.html#nlp-model-adaptive-allocations[Adaptive allocations] will be enabled with the minimum of 1 and the maximum of 10 allocations. + . Define a `text_similarity_rerank` retriever: + diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml new file mode 100644 index 0000000000000..8a12e54221c55 --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -0,0 +1,239 @@ +setup: + - requires: + cluster_features: [ "linear_retriever_supported" ] + reason: "Support for linear retriever" + test_runner_features: close_to + + - do: + indices.create: + index: test + body: + mappings: + properties: + vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + keyword: + type: keyword + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"vector": [1], "keyword": "one"}' + - '{"index": {"_id": 2 }}' + - '{"vector": [2], "keyword": "two"}' + - '{"index": {"_id": 3 }}' + - '{"vector": [3], "keyword": "three"}' + - '{"index": {"_id": 4 }}' + - '{"vector": [4], "keyword": "four"}' +--- +"basic linear weighted combination of a standard and knn retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 5.0 } + - match: { hits.hits.1._id: "4" } + - match: { hits.hits.1._score: 2.0 } + +--- +"should normalize initial scores": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "1" } + - close_to: { hits.hits.0._score: { value: 10.0, error: 0.001 } } + - match: { hits.hits.1._id: "2" } + - close_to: { hits.hits.1._score: { value: 8.0, error: 0.001 } } + - match: { hits.hits.2._id: "4" } + - match: { hits.hits.2._score: 2.0 } + - match: { hits.hits.3._id: "3" } + - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } + +--- +"should throw on unknown normalizer": + - do: + catch: /Unknown \[aardvark\] ScoreNormalizer provided/ + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 1.0, + normalizer: "aardvark" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + +#--- +#"pagination within a consistent rank_window_size": +# +#--- +#"explain should provide info on weights and inner retrievers": +# +#--- +#"collapsing results": +# +#--- +#"highlighting results": +# +#--- +#"multiple nested linear retrievers": +# +#--- +#"linear retriever with filters": +# +# +#--- +#"linear retriever with filters on nested retrievers": +# diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java index 170ddcc44a6e8..6998e5d81c79c 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java @@ -12,28 +12,29 @@ import org.apache.lucene.search.Explanation; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Arrays; +import java.util.Objects; public class LinearRankDoc extends RankDoc { public float[] weights; - public float[] scores; + public float[] normalizedScores; public String[] normalizers; public LinearRankDoc(int doc, float score, int shardIndex, int queriesCount) { super(doc, score, shardIndex); this.weights = new float[queriesCount]; - this.scores = new float[queriesCount]; - Arrays.fill(scores, 0f); + this.normalizedScores = new float[queriesCount]; this.normalizers = new String[queriesCount]; } public LinearRankDoc(StreamInput in) throws IOException { super(in.readVInt(), in.readFloat(), in.readVInt()); weights = in.readFloatArray(); - scores = in.readFloatArray(); + normalizedScores = in.readFloatArray(); normalizers = in.readStringArray(); } @@ -43,17 +44,17 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { for (int i = 0; i < sources.length; i++) { final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]"; final String queryIdentifier = "at index [" + i + "]" + queryAlias; - if (scores[i] > 0) { + if (normalizedScores[i] > 0) { details[i] = Explanation.match( - weights[i] * scores[i], + weights[i] * normalizedScores[i], "weighted score: [" - + weights[i] * scores[i] + + weights[i] * normalizedScores[i] + "] in query " + queryIdentifier + " computed as [" + weights[i] + " * " - + scores[i] + + normalizedScores[i] + "]" + " using score normalizer [" + normalizers[i] @@ -71,7 +72,7 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { "weighted linear combination score: [" + score + "] computed for normalized scores " - + Arrays.toString(scores) + + Arrays.toString(normalizedScores) + " and weights " + Arrays.toString(weights) + "] as sum of (weight[i] * score[i]) for each query.", @@ -82,7 +83,28 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeFloatArray(weights); - out.writeFloatArray(scores); + out.writeFloatArray(normalizedScores); out.writeStringArray(normalizers); } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("weights", weights); + builder.field("normalizedScores", normalizedScores); + builder.field("normalizers", normalizers); + } + + @Override + public boolean doEquals(RankDoc rd) { + LinearRankDoc lrd = (LinearRankDoc) rd; + return Arrays.equals(weights, lrd.weights) + && Arrays.equals(normalizedScores, lrd.normalizedScores) + && Arrays.equals(normalizers, lrd.normalizers); + } + + @Override + public int doHashCode() { + int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores) + Arrays.hashCode(normalizers)); + return 31 * result; + } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 95cfc9edb0b59..764181ea5c038 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -23,6 +23,7 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -38,66 +39,44 @@ * normalizer parameter. * */ -public class LinearRetrieverBuilder extends CompoundRetrieverBuilder { +public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder { - public static final String NAME = "linear_retriever"; + public static final String NAME = "linear"; - public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_support"); + public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported"); public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers"); public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size"); - private final List wrappedRetrievers; - - static final float DEFAULT_WEIGHT = 1f; - static final ScoreNormalizer DEFAULT_NORMALIZER = ScoreNormalizer.IDENTITY; + private final float[] weights; + private final LinearRetrieverComponent.ScoreNormalizer[] normalizers; @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, false, args -> { - List childRetrievers = (List) args[0]; - List innerRetrievers = childRetrievers.stream().map(r -> new RetrieverSource(r.retriever, null)).toList(); + List retrieverComponents = (List) args[0]; int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1]; - return new LinearRetrieverBuilder(childRetrievers, innerRetrievers, rankWindowSize); + List innerRetrievers = new ArrayList<>(); + float[] weights = new float[retrieverComponents.size()]; + LinearRetrieverComponent.ScoreNormalizer[] normalizers = new LinearRetrieverComponent.ScoreNormalizer[retrieverComponents + .size()]; + int index = 0; + for (LinearRetrieverComponent component : retrieverComponents) { + innerRetrievers.add(new RetrieverSource(component.retriever, null)); + weights[index] = component.weight; + normalizers[index] = component.normalizer; + index++; + } + return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); } ); - // public record WrappedRetriever(RetrieverBuilder retrieverBuilder, float weight, ScoreNormalizer normalizer) {} - static { PARSER.declareObjectArray(constructorArg(), (p, c) -> { - // float weight = -1f; - // ScoreNormalizer normalizer = null; - // RetrieverBuilder retrieverBuilder = null; - // while (p.nextToken() != null && p.currentName() != null) { - // String name = p.currentName(); - // switch (name) { - // case "weight": - // p.nextToken(); - // weight = p.floatValue(); - // break; - // case "retriever": - // p.nextToken(); - // p.nextToken(); - // retrieverBuilder = p.namedObject(RetrieverBuilder.class, p.currentName(), c); - // c.trackRetrieverUsage(retrieverBuilder.getName()); - // p.nextToken(); - // break; - // case "normalizer": - // p.nextToken(); - // String normalizerName = p.text(); - // normalizer = ScoreNormalizer.find(normalizerName); - // break; - // default: - // throw new ParsingException(p.getTokenLocation(), "Unknown key {" + name + "} provided"); - // } - // } - // ; - // return new WrappedRetriever(retrieverBuilder, weight, normalizer); p.nextToken(); - WrapperRetrieverBuilder retrieverBuilder = WrapperRetrieverBuilder.fromXContent(p, c); + LinearRetrieverComponent retrieverBuilder = LinearRetrieverComponent.fromXContent(p, c); p.nextToken(); return retrieverBuilder; }, RETRIEVERS_FIELD); @@ -112,18 +91,20 @@ public static LinearRetrieverBuilder fromXContent(XContentParser parser, Retriev return PARSER.apply(parser, context); } - protected LinearRetrieverBuilder( - List wrappedRetrievers, + public LinearRetrieverBuilder( List innerRetrievers, - int rankWindowSize + int rankWindowSize, + float[] weights, + LinearRetrieverComponent.ScoreNormalizer[] normalizers ) { super(innerRetrievers, rankWindowSize); - this.wrappedRetrievers = wrappedRetrievers; + this.weights = weights; + this.normalizers = normalizers; } @Override protected LinearRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { - LinearRetrieverBuilder clone = new LinearRetrieverBuilder(wrappedRetrievers, newChildRetrievers, rankWindowSize); + LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; return clone; } @@ -131,30 +112,33 @@ protected LinearRetrieverBuilder clone(List newChildRetrievers, @Override protected RankDoc[] combineInnerRetrieverResults(List rankResults) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); - for (int resIndex = 0; resIndex < rankResults.size(); resIndex++) { - ScoreDoc[] originalScoreDocs = rankResults.get(resIndex); - ScoreDoc[] normalizedScoreDocs = wrappedRetrievers.get(resIndex).normalizer.normalizeScores(originalScoreDocs); - for (int i = 0; i < normalizedScoreDocs.length; i++) { - int finalResIndex = resIndex; - int finalI = i; - docsToRankResults.compute(new RankDoc.RankKey(originalScoreDocs[i].doc, originalScoreDocs[i].shardIndex), (key, value) -> { - if (value == null) { - value = new LinearRankDoc( - originalScoreDocs[finalI].doc, - 0, - originalScoreDocs[finalI].shardIndex, - rankResults.size() - ); + for (int result = 0; result < rankResults.size(); result++) { + ScoreDoc[] originalScoreDocs = rankResults.get(result); + ScoreDoc[] normalizedScoreDocs = normalizers[result].normalizeScores(originalScoreDocs); + for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { + int finalResult = result; + int finalScoreIndex = scoreDocIndex; + docsToRankResults.compute( + new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex), + (key, value) -> { + if (value == null) { + value = new LinearRankDoc( + originalScoreDocs[finalScoreIndex].doc, + 0, + originalScoreDocs[finalScoreIndex].shardIndex, + rankResults.size() + ); + } + value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; + value.weights[finalResult] = weights[finalResult]; + value.normalizers[finalResult] = normalizers[finalResult].name(); + value.score += weights[finalResult] * normalizedScoreDocs[finalScoreIndex].score; + return value; } - value.scores[finalResIndex] = normalizedScoreDocs[finalI].score; - value.weights[finalResIndex] = wrappedRetrievers.get(finalResIndex).weight; - value.normalizers[finalResIndex] = wrappedRetrievers.get(finalResIndex).normalizer.name(); - value.score += wrappedRetrievers.get(finalResIndex).weight * normalizedScoreDocs[finalI].score; - return value; - }); + ); } } - // sort the results based on rrf score, tiebreaker based on smaller doc id + // sort the results based on the final score, tiebreaker based on smaller doc id LinearRankDoc[] sortedResults = docsToRankResults.values().toArray(LinearRankDoc[]::new); Arrays.sort(sortedResults); // trim the results if needed, otherwise each shard will always return `rank_window_size` results. @@ -171,61 +155,22 @@ public String getName() { return NAME; } - @Override - protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - - } - - enum ScoreNormalizer { - IDENTITY("identity") { - @Override - public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { - // no-op - return docs; - } - }, - MINMAX("minmax") { - @Override - public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { - // create a new array to avoid changing ScoreDocs in place - ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; - // to avoid 0 scores - float epsilon = Float.MIN_NORMAL; - float min = Float.MAX_VALUE; - float max = Float.MIN_VALUE; - for (ScoreDoc rd : docs) { - if (rd.score > max) { - max = rd.score; - } - if (rd.score < min) { - min = rd.score; - } - } - for (int i = 0; i < docs.length; i++) { - float score = epsilon + ((docs[i].score - min) / (max - min)); - scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); - } - return scoreDocs; - } - }; - - private final String name; - - ScoreNormalizer(String name) { - this.name = name; - } - - abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); - - static ScoreNormalizer find(String name) { - for (ScoreNormalizer normalizer : values()) { - if (normalizer.name.equalsIgnoreCase(name)) { - return normalizer; - } + public void doToXContent(XContentBuilder builder, Params params) throws IOException { + int index = 0; + if (innerRetrievers.isEmpty() == false) { + builder.startArray(RETRIEVERS_FIELD.getPreferredName()); + for (var entry : innerRetrievers) { + builder.startObject(); + builder.startObject(LinearRetrieverComponent.NAME); + builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); + builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); + builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].name()); + builder.endObject(); + builder.endObject(); + index++; } - throw new IllegalArgumentException( - "Unknown normalizer [" + name + "] provided. Supported values are: " + Arrays.stream(values()).map(Enum::name).toList() - ); + builder.endArray(); } + builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java new file mode 100644 index 0000000000000..c9c363ae56521 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.RETRIEVERS_FIELD; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class LinearRetrieverComponent implements ToXContentObject { + + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + public static final ParseField WEIGHT_FIELD = new ParseField("weight"); + public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer"); + + static final float DEFAULT_WEIGHT = 1f; + static final ScoreNormalizer DEFAULT_NORMALIZER = ScoreNormalizer.IDENTITY; + + public static final String NAME = "component"; + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + false, + args -> { + RetrieverBuilder base = (RetrieverBuilder) args[0]; + float weight = args[1] == null ? DEFAULT_WEIGHT : (float) args[1]; + ScoreNormalizer normalizer = args[2] == null ? DEFAULT_NORMALIZER : ScoreNormalizer.fromString((String) args[2]); + return new LinearRetrieverComponent(base, weight, normalizer); + } + ); + + static { + PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { + RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, n, c); + c.trackRetrieverUsage(retrieverBuilder.getName()); + return retrieverBuilder; + }, RETRIEVER_FIELD); + PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); + PARSER.declareString(optionalConstructorArg(), NORMALIZER_FIELD); + } + + RetrieverBuilder retriever; + float weight; + ScoreNormalizer normalizer; + + public LinearRetrieverComponent(RetrieverBuilder base, float weight, ScoreNormalizer normalizer) { + this.retriever = base; + this.weight = weight; + this.normalizer = normalizer; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(RETRIEVERS_FIELD.getPreferredName(), retriever); + return builder; + } + + public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + return PARSER.apply(parser, context); + } + + public enum ScoreNormalizer { + IDENTITY { + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + // no-op + return docs; + } + }, + MINMAX { + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + // create a new array to avoid changing ScoreDocs in place + ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; + // to avoid 0 scores + float epsilon = Float.MIN_NORMAL; + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + for (ScoreDoc rd : docs) { + if (rd.score > max) { + max = rd.score; + } + if (rd.score < min) { + min = rd.score; + } + } + for (int i = 0; i < docs.length; i++) { + float score = epsilon + ((docs[i].score - min) / (max - min)); + scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); + } + return scoreDocs; + } + }; + + public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); + + public static ScoreNormalizer fromString(final String normalizerName){ + for(ScoreNormalizer normalizer: values()){ + if(normalizer.name().equalsIgnoreCase(normalizerName)){ + return normalizer; + } + } + throw new IllegalArgumentException("Unknown [" + normalizerName + "] ScoreNormalizer provided."); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java deleted file mode 100644 index dbd5d5e16491f..0000000000000 --- a/server/src/main/java/org/elasticsearch/search/retriever/WrapperRetrieverBuilder.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.search.retriever; - -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; - -import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.DEFAULT_NORMALIZER; -import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.DEFAULT_WEIGHT; -import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.RETRIEVERS_FIELD; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; - -public class WrapperRetrieverBuilder implements ToXContentObject { - - public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); - public static final ParseField WEIGHT_FIELD = new ParseField("weight"); - public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer"); - - public static final String NAME = "component"; - static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, - false, - args -> { - RetrieverBuilder base = (RetrieverBuilder) args[0]; - float weight = args[1] == null ? DEFAULT_WEIGHT : (float) args[1]; - LinearRetrieverBuilder.ScoreNormalizer normalizer = args[2] == null - ? DEFAULT_NORMALIZER - : LinearRetrieverBuilder.ScoreNormalizer.find((String) args[2]); - return new WrapperRetrieverBuilder(base, weight, normalizer); - } - ); - - static { - PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { - RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, n, c); - c.trackRetrieverUsage(retrieverBuilder.getName()); - return retrieverBuilder; - }, RETRIEVER_FIELD); - PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); - PARSER.declareString(optionalConstructorArg(), NORMALIZER_FIELD); - } - - RetrieverBuilder retriever; - float weight; - LinearRetrieverBuilder.ScoreNormalizer normalizer; - - public WrapperRetrieverBuilder(RetrieverBuilder base, float weight, LinearRetrieverBuilder.ScoreNormalizer normalizer) { - this.retriever = base; - this.weight = weight; - this.normalizer = normalizer; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVERS_FIELD.getPreferredName(), retriever); - return builder; - } - - public static WrapperRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { - return PARSER.apply(parser, context); - } -} diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java new file mode 100644 index 0000000000000..222072dda1d57 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.usage.SearchUsage; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.XContentParser; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Collections.emptyList; + +public class LinearRetrieverBuilderParsingTests extends AbstractXContentTestCase { + private static List xContentRegistryEntries; + + @BeforeClass + public static void init() { + xContentRegistryEntries = new SearchModule(Settings.EMPTY, emptyList()).getNamedXContents(); + } + + @AfterClass + public static void afterClass() throws Exception { + xContentRegistryEntries = null; + } + + @Override + protected LinearRetrieverBuilder createTestInstance() { + int rankWindowSize = randomInt(100); + int num = randomIntBetween(1, 3); + List innerRetrievers = new ArrayList<>(); + float[] weights = new float[num]; + LinearRetrieverComponent.ScoreNormalizer[] normalizers = new LinearRetrieverComponent.ScoreNormalizer[num]; + for (int i = 0; i < num; i++) { + innerRetrievers.add( + new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null) + ); + weights[i] = randomFloat(); + normalizers[i] = randomFrom(LinearRetrieverComponent.ScoreNormalizer.values()); + } + return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); + } + + @Override + protected LinearRetrieverBuilder doParseInstance(XContentParser parser) throws IOException { + return (LinearRetrieverBuilder) RetrieverBuilder.parseTopLevelRetrieverBuilder( + parser, + new RetrieverParserContext(new SearchUsage(), n -> true) + ); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List entries = new ArrayList<>(xContentRegistryEntries); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + TestRetrieverBuilder.TEST_SPEC.getName(), + (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c), + TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion() + ) + ); + return new NamedXContentRegistry(entries); + } +} From a7da4f3a75146bf7f4ec84ce9530d9f1d8f47bd5 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 13 Jan 2025 17:25:36 +0200 Subject: [PATCH 03/57] iter --- .../search.retrievers/40_linear_retriever.yml | 288 +++++++++++++++++- .../retriever/CompoundRetrieverBuilder.java | 6 +- .../retriever/LinearRetrieverBuilder.java | 1 - .../retriever/LinearRetrieverComponent.java | 6 +- 4 files changed, 292 insertions(+), 9 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index 8a12e54221c55..a00b5f21f7628 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -215,9 +215,291 @@ setup: } ] -#--- -#"pagination within a consistent rank_window_size": -# +--- +"pagination within a consistent rank_window_size": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + from: 2 + size: 1 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + from: 3 + size: 1 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "3" } + - close_to: { hits.hits.0._score: { value: 0.0, error: 0.001 } } + +--- +"should throw when rank_window_size less than default size": + - do: + catch: "/\\[linear\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/" + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + match_all: { } + } + } + }, + weight: 10.0, + normalizer: "minmax" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + rank_window_size: 2 + +--- +"should respect rank_window_size for normalization and returned hits": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + rank_window_size: 2 + size: 2 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 2.0, error: 0.001 } } + - match: { hits.hits.1._id: "1" } + - close_to: { hits.hits.1._score: { value: 1.0, error: 0.001 } } + #--- #"explain should provide info on weights and inner retrievers": # diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 298340e5c579e..576e3c2d240f8 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -41,6 +41,7 @@ import java.util.Objects; import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; /** * This abstract retriever defines a compound retriever. The idea is that it is not a leaf-retriever, i.e. it does not @@ -219,7 +220,8 @@ public ActionRequestValidationException validate( boolean allowPartialSearchResults ) { validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); - if (source.size() > rankWindowSize) { + final int size = source.size() < 0 ? DEFAULT_SIZE : source.size(); + if (size > rankWindowSize) { validationException = addValidationError( String.format( Locale.ROOT, @@ -227,7 +229,7 @@ public ActionRequestValidationException validate( getName(), getRankWindowSizeField().getPreferredName(), rankWindowSize, - source.size() + size ), validationException ); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 764181ea5c038..57c5d902daf64 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -46,7 +46,6 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder Date: Mon, 13 Jan 2025 21:23:48 +0200 Subject: [PATCH 04/57] iter --- .../search.retrievers/40_linear_retriever.yml | 358 +++++++++++++++++- .../search/rank/LinearRankDoc.java | 14 +- .../retriever/LinearRetrieverBuilder.java | 7 +- 3 files changed, 347 insertions(+), 32 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index a00b5f21f7628..35097160e9de6 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -17,6 +17,8 @@ setup: similarity: l2_norm keyword: type: keyword + other_keyword: + type: keyword - do: bulk: @@ -24,13 +26,14 @@ setup: index: test body: - '{"index": {"_id": 1 }}' - - '{"vector": [1], "keyword": "one"}' + - '{"vector": [1], "keyword": "one", "other_keyword": "other"}' - '{"index": {"_id": 2 }}' - '{"vector": [2], "keyword": "two"}' - '{"index": {"_id": 3 }}' - '{"vector": [3], "keyword": "three"}' - '{"index": {"_id": 4 }}' - - '{"vector": [4], "keyword": "four"}' + - '{"vector": [4], "keyword": "four", "other_keyword": "other"}' + --- "basic linear weighted combination of a standard and knn retrievers": - do: @@ -500,22 +503,335 @@ setup: - match: { hits.hits.1._id: "1" } - close_to: { hits.hits.1._score: { value: 1.0, error: 0.001 } } -#--- -#"explain should provide info on weights and inner retrievers": -# -#--- -#"collapsing results": -# -#--- -#"highlighting results": -# -#--- -#"multiple nested linear retrievers": -# -#--- -#"linear retriever with filters": -# -# -#--- -#"linear retriever with filters on nested retrievers": -# +--- +"explain should provide info on weights and inner retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 1.0 + } + } + ] + } + }, + _name: "my_standard_retriever" + } + }, + weight: 10.0, + normalizer: "minmax" + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 20.0 + } + } + ] + explain: true + size: 2 + + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._explanation.description: "/weighted.linear.combination.score:.\\[20.0].computed.for.normalized.scores.\\[.*,.1.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} + - close_to: { hits.hits.0._explanation.details.0.value: { value: 0.0, error: 0.001 } } + - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[MINMAX\\].*/" } + - close_to: { hits.hits.0._explanation.details.1.value: { value: 20.0, error: 0.001 } } + - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[IDENTITY\\].*/" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._explanation.description: "/weighted.linear.combination.score:.\\[10.0].computed.for.normalized.scores.\\[1.0,.0.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} + - close_to: { hits.hits.1._explanation.details.0.value: { value: 10.0, error: 0.001 } } + - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*using.score.normalizer.\\[MINMAX\\].*/" } + - close_to: { hits.hits.1._explanation.details.1.value: { value: 0.0, error: 0.001 } } + - match: { hits.hits.1._explanation.details.1.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[1\\]/" } + +--- +"collapsing results": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + collapse: + field: other_keyword + inner_hits: { + name: sub_hits, + sort: + { + keyword: { + order: desc + } + } + } + - match: { hits.hits.0._id: "1" } + - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "1" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "4" } + +--- +"multiple nested linear retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + } + }, + { + component: { + retriever: { + linear: { + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 20.0 + } + } + } + } + } + }, + { + component: + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + } + } + } + ] + } + }, + weight: 2.0 + } + } + ] + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0._score: 40.0 } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 5.0 } + - match: { hits.hits.2._id: "4" } + - match: { hits.hits.2._score: 2.0 } + +--- +"linear retriever with filters": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + filter: + term: + keyword: "four" + + + - match: { hits.total.value: 1 } + - length: {hits.hits: 1} + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + +--- +"linear retriever with filters on nested retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + filter: { + term: { + keyword: "four" + } + } + } + }, + weight: 0.5 + } + }, + { + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + } + ] + + - match: { hits.total.value: 1 } + - length: {hits.hits: 1} + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java index 6998e5d81c79c..41bf73597ea02 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java @@ -20,15 +20,15 @@ public class LinearRankDoc extends RankDoc { - public float[] weights; + private final float[] weights; + private final String[] normalizers; public float[] normalizedScores; - public String[] normalizers; - public LinearRankDoc(int doc, float score, int shardIndex, int queriesCount) { + public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) { super(doc, score, shardIndex); - this.weights = new float[queriesCount]; - this.normalizedScores = new float[queriesCount]; - this.normalizers = new String[queriesCount]; + this.weights = weights; + this.normalizers = normalizers; + this.normalizedScores = new float[normalizers.length]; } public LinearRankDoc(StreamInput in) throws IOException { @@ -75,7 +75,7 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { + Arrays.toString(normalizedScores) + " and weights " + Arrays.toString(weights) - + "] as sum of (weight[i] * score[i]) for each query.", + + " as sum of (weight[i] * score[i]) for each query.", details ); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 57c5d902daf64..0d71c29052690 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -123,14 +123,13 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { if (value == null) { value = new LinearRankDoc( originalScoreDocs[finalScoreIndex].doc, - 0, + 0f, originalScoreDocs[finalScoreIndex].shardIndex, - rankResults.size() + weights, + Arrays.stream(normalizers).map(Enum::name).toArray(String[]::new) ); } value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; - value.weights[finalResult] = weights[finalResult]; - value.normalizers[finalResult] = normalizers[finalResult].name(); value.score += weights[finalResult] * normalizedScoreDocs[finalScoreIndex].score; return value; } From d64effa944c1a77a28e7fcd7b0b9f820a33261ca Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 14 Jan 2025 14:54:05 +0200 Subject: [PATCH 05/57] iter --- docs/reference/rest-api/common-parms.asciidoc | 38 +++++- docs/reference/search/retriever.asciidoc | 12 +- docs/reference/search/rrf.asciidoc | 12 +- .../retrievers-examples.asciidoc | 124 ++++++++++++++++++ .../retrievers-overview.asciidoc | 3 + 5 files changed, 173 insertions(+), 16 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 83c11c9256a67..b9abd60c4f8ad 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1338,7 +1338,7 @@ that lower ranked documents have more influence. This value must be greater than equal to `1`. Defaults to `60`. end::rrf-rank-constant[] -tag::rrf-rank-window-size[] +tag::compound-retriever-rank-window-size[] `rank_window_size`:: (Optional, integer) + @@ -1347,12 +1347,42 @@ query. A higher value will improve result relevance at the cost of performance. ranked result set is pruned down to the search request's <>. `rank_window_size` must be greater than or equal to `size` and greater than or equal to `1`. Defaults to the `size` parameter. -end::rrf-rank-window-size[] +end::compound-retriever-rank-window-size[] -tag::rrf-filter[] +tag::compound-retriever-filter[] `filter`:: (Optional, <>) + Applies the specified <> to all of the specified sub-retrievers, according to each retriever's specifications. -end::rrf-filter[] +end::compound-retriever-filter[] + +tag::linear-retriever-components[] +`components`:: +(Required, array of `component` objects) ++ +A list of the components, i.e. the sub-retrievers' configuration, that we will take into account and whose result sets +we will merge through a weighted sum. Each component can have a different weight and normalization depending +on the specified retriever. + +Each `component` entry specifies the following parameters: + +* `retriever`:: +(Required, a <> object) ++ +Specifies the retriever for which we will compute the top documents for. The retriever will produce `rank_window_size` +results, which will later be merged based on the specified `weight` and `normalizer`. + +* `weight`:: +(Optional, float) ++ +The weight that each score of this retriever's top docs will be multiplied with. Defaults to 1.0. + +* `normalizer`:: +(Optional, string) ++ +Specifies how we will normalize the retriever's scores, before applying the specified `weight`. +Available values are: `minmax`, `identity`. Defaults to `identity`. + +See also <> using a linear retriever. +end::linear-retriever-components[] diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index 28baee4d4c7ad..d33e245f5a5ab 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -272,12 +272,12 @@ A retriever that normalizes and linearly combines the scores of other retrievers ===== Parameters -`retrievers`:: -(Required, A list of <>) +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components] -sth +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size] + +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter] -+ [[rrf-retriever]] ==== RRF Retriever @@ -290,9 +290,9 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers] include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant] -include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size] +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size] -include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-filter] +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filter] [discrete] [[rrf-retriever-example-hybrid]] diff --git a/docs/reference/search/rrf.asciidoc b/docs/reference/search/rrf.asciidoc index 842bd7049e3bf..59976cec9c0aa 100644 --- a/docs/reference/search/rrf.asciidoc +++ b/docs/reference/search/rrf.asciidoc @@ -45,7 +45,7 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers] include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-constant] -include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-rank-window-size] +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-rank-window-size] An example request using RRF: @@ -791,11 +791,11 @@ A more specific example of highlighting in RRF can also be found in the <> functionality, allowing you to retrieve -related nested or parent/child documents alongside your main search results. Inner hits can be -specified as part of any nested sub-retriever and will be propagated to the top-level parent -retriever. Note that the inner hit computation will take place only at end of `rrf` retriever's -evaluation on the top matching documents, and not as part of the query execution of the nested +The `rrf` retriever supports <> functionality, allowing you to retrieve +related nested or parent/child documents alongside your main search results. Inner hits can be +specified as part of any nested sub-retriever and will be propagated to the top-level parent +retriever. Note that the inner hit computation will take place only at end of `rrf` retriever's +evaluation on the top matching documents, and not as part of the query execution of the nested sub-retrievers. [IMPORTANT] diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 5cada8960aeab..a8939a7d97660 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -178,6 +178,130 @@ This returns the following response based on the final rrf score for each result // TESTRESPONSE[s/"took": 42/"took": $body.took/] ============== +[discrete] +[[retrievers-examples-linear-retriever]] +==== Example: Hybrid search with linear retriever + +A different, and more intuitive, way to provide hybrid search, is to linearly combine the top documents of different +retrievers using a weighted sum of the original scores. Since, as above, the scores could lie in different ranges, +we can also specify a `normalizer` that would ensure that all scores for the top ranked documents of a retriever +lie in a specific range. + +To implement this, we define a `linear` retriever, and a set of `components` as the nested heterogeneous results sets +that we will combine. We will solve a problem similar to the above, by merging the results of a `standard` and a `knn` +retriever. As the `standard` retriever's scores are based on BM25 and are not strictly bounded, we will also define a +`minmax` normalizer to ensure that the scores lie in the [0, 1] range. We will apply the same normalizer to `knn` as well +to ensure that we capture the importance of each document within the result set. + +So, let's now specify the `linear` retriever whose final score is computed as follows: + +[source] +---- +score = weight(standard) * score(standard) + weight(knn) * score(knn) +score = 2 * score(standard) + 1.5 * score(knn) +---- +// NOTCONSOLE + +[source,console] +---- +GET /retrievers_example/_search +{ + "retriever": { + "linear": { + "retrievers": [ + { + "component": { + "retriever": { + "standard": { + "query": { + "query_string": { + "query": "(information retrieval) OR (artificial intelligence)", + "default_field": "text" + } + } + } + }, + "weight": 2, + "normalizer": "minmax" + } + }, + { + "component": { + "retriever": { + "knn": { + "field": "vector", + "query_vector": [ + 0.23, + 0.67, + 0.89 + ], + "k": 3, + "num_candidates": 5 + } + }, + "weight": 1.5, + "normalizer": "minmax" + } + } + ], + "rank_window_size": 10 + } + }, + "_source": false +} +---- +// TEST + +This returns the following response based on the final linearly weighted score for each result. + +.Example response +[%collapsible] +============== +[source,console-result] +---- +{ + "took": 42, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 3, + "relation": "eq" + }, + "max_score": -1, + "hits": [ + { + "_index": "retrievers_example", + "_id": "2", + "_score": -1 + }, + { + "_index": "retrievers_example", + "_id": "1", + "_score": -2 + }, + { + "_index": "retrievers_example", + "_id": "5", + "_score": -3 + } + ] + } +} +---- +// TESTRESPONSE[s/"took": 42/"took": $body.took/] +// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/] +// TESTRESPONSE[s/"score": -1/"score": $body.hits.hits.0._score/] +// TESTRESPONSE[s/"score": -2/"score": $body.hits.hits.1._score/] +// TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] +============== + + [discrete] [[retrievers-examples-collapsing-retriever-results]] ==== Example: Grouping results by year with `collapse` diff --git a/docs/reference/search/search-your-data/retrievers-overview.asciidoc b/docs/reference/search/search-your-data/retrievers-overview.asciidoc index 1771b5bb0d849..1a94ae18a5c20 100644 --- a/docs/reference/search/search-your-data/retrievers-overview.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-overview.asciidoc @@ -23,6 +23,9 @@ This ensures backward compatibility as existing `_search` requests remain suppor That way you can transition to the new abstraction at your own pace without mixing syntaxes. * <>. Returns top documents from a <>, in the context of a retriever framework. +* <>. +Combines the top results from multiple sub-retrievers using a weighted sum of their scores. Allows to specify different +weights for each retriever, as well as independently normalize the scores from each result set. * <>. Combines and ranks multiple first-stage retrievers using the reciprocal rank fusion (RRF) algorithm. Allows you to combine multiple result sets with different relevance indicators into a single result set. From b945acfa9a1e0e6c3bff3ad5a5b889e9811e152d Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 14 Jan 2025 22:02:25 +0200 Subject: [PATCH 06/57] iter --- .../search.retrievers/40_linear_retriever.yml | 36 ++++--- server/src/main/java/module-info.java | 2 + .../elasticsearch/plugins/SearchPlugin.java | 42 ++++++++ .../elasticsearch/search/SearchModule.java | 21 ++++ .../normalizer/IdentityScoreNormalizer.java | 51 ++++++++++ .../normalizer/MinMaxScoreNormalizer.java | 95 +++++++++++++++++++ .../search/normalizer/ScoreNormalizer.java | 34 +++++++ .../normalizer/ScoreNormalizerParser.java | 33 +++++++ .../retriever/LinearRetrieverBuilder.java | 13 ++- .../retriever/LinearRetrieverComponent.java | 67 ++++--------- .../LinearRetrieverBuilderParsingTests.java | 10 +- 11 files changed, 337 insertions(+), 67 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java create mode 100644 server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java create mode 100644 server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java create mode 100644 server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index 35097160e9de6..ab6f1169d9188 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -143,7 +143,9 @@ setup: } }, weight: 10.0, - normalizer: "minmax" + normalizer: { + minmax: { } + } } }, { @@ -174,7 +176,7 @@ setup: --- "should throw on unknown normalizer": - do: - catch: /Unknown \[aardvark\] ScoreNormalizer provided/ + catch: /unknown field \[aardvark\]/ search: index: test body: @@ -200,7 +202,9 @@ setup: } }, weight: 1.0, - normalizer: "aardvark" + normalizer: { + aardvark: { } + } } }, { @@ -276,7 +280,9 @@ setup: } }, weight: 10.0, - normalizer: "minmax" + normalizer: { + minmax: { } + } } }, { @@ -356,7 +362,9 @@ setup: } }, weight: 10.0, - normalizer: "minmax" + normalizer: { + minmax: { } + } } }, { @@ -400,7 +408,9 @@ setup: } }, weight: 10.0, - normalizer: "minmax" + normalizer: { + minmax: { } + } } }, { @@ -477,7 +487,9 @@ setup: } }, weight: 1.0, - normalizer: "minmax" + normalizer: { + minmax: { } + } } }, { @@ -550,7 +562,9 @@ setup: } }, weight: 10.0, - normalizer: "minmax" + normalizer: { + minmax: { } + } } }, { @@ -573,13 +587,13 @@ setup: - match: { hits.hits.0._id: "4" } - match: { hits.hits.0._explanation.description: "/weighted.linear.combination.score:.\\[20.0].computed.for.normalized.scores.\\[.*,.1.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} - close_to: { hits.hits.0._explanation.details.0.value: { value: 0.0, error: 0.001 } } - - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[MINMAX\\].*/" } + - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[minmax\\].*/" } - close_to: { hits.hits.0._explanation.details.1.value: { value: 20.0, error: 0.001 } } - - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[IDENTITY\\].*/" } + - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[none\\].*/" } - match: { hits.hits.1._id: "1" } - match: { hits.hits.1._explanation.description: "/weighted.linear.combination.score:.\\[10.0].computed.for.normalized.scores.\\[1.0,.0.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} - close_to: { hits.hits.1._explanation.details.0.value: { value: 10.0, error: 0.001 } } - - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*using.score.normalizer.\\[MINMAX\\].*/" } + - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*using.score.normalizer.\\[minmax\\].*/" } - close_to: { hits.hits.1._explanation.details.1.value: { value: 0.0, error: 0.001 } } - match: { hits.hits.1._explanation.details.1.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[1\\]/" } diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 4112290fa4e04..8fed30a348ba7 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -55,6 +55,7 @@ requires org.apache.lucene.queryparser; requires org.apache.lucene.sandbox; requires org.apache.lucene.suggest; + requires java.desktop; exports org.elasticsearch; exports org.elasticsearch.action; @@ -355,6 +356,7 @@ exports org.elasticsearch.search.fetch.subphase.highlight; exports org.elasticsearch.search.internal; exports org.elasticsearch.search.lookup; + exports org.elasticsearch.search.normalizer; exports org.elasticsearch.search.profile; exports org.elasticsearch.search.profile.aggregation; exports org.elasticsearch.search.profile.dfs; diff --git a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java index e87e9ee85b29c..ac706dab459ea 100644 --- a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java @@ -37,6 +37,8 @@ import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.normalizer.ScoreNormalizer; +import org.elasticsearch.search.normalizer.ScoreNormalizerParser; import org.elasticsearch.search.rescore.Rescorer; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; @@ -316,6 +318,46 @@ public RetrieverParser getParser() { } } + /** + * Specification of custom {@link ScoreNormalizer}. + */ + class ScoreNormalizerSpec { + + private final ParseField name; + private final ScoreNormalizerParser parser; + + /** + * Specification of custom {@link ScoreNormalizer}. + * + * @param name holds the names by which this score normalizer might be parsed. The {@link ParseField#getPreferredName()} + * is special as it is the name by under which the reader is registered. So it is the name that the normalizer + * should use as its {@link NamedWriteable#getWriteableName()} too. + * @param parser the parser the reads the retriever builder from xcontent + */ + public ScoreNormalizerSpec(ParseField name, ScoreNormalizerParser parser) { + this.name = name; + this.parser = parser; + } + + /** + * Specification of custom {@link ScoreNormalizer}. + * + * @param name the name by which this normalizer might be parsed or deserialized + * @param parser the parser the reads the {@code ScoreNormalizer} from xcontent + */ + public ScoreNormalizerSpec(String name, ScoreNormalizerParser parser) { + this(new ParseField(name), parser); + } + + public ParseField getName() { + return name; + } + + public ScoreNormalizerParser getParser() { + return parser; + } + } + /** * Specification of custom {@link Query}. */ diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index ead4df3ac2c4f..38867205041df 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -87,6 +87,7 @@ import org.elasticsearch.plugins.SearchPlugin.RescorerSpec; import org.elasticsearch.plugins.SearchPlugin.RetrieverSpec; import org.elasticsearch.plugins.SearchPlugin.ScoreFunctionSpec; +import org.elasticsearch.plugins.SearchPlugin.ScoreNormalizerSpec; import org.elasticsearch.plugins.SearchPlugin.SearchExtSpec; import org.elasticsearch.plugins.SearchPlugin.SignificanceHeuristicSpec; import org.elasticsearch.plugins.SearchPlugin.SuggesterSpec; @@ -224,6 +225,9 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.fetch.subphase.highlight.PlainHighlighter; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.normalizer.IdentityScoreNormalizer; +import org.elasticsearch.search.normalizer.MinMaxScoreNormalizer; +import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.search.rank.feature.RankFeatureDoc; @@ -351,6 +355,7 @@ public SearchModule(Settings settings, List plugins, TelemetryProv highlighters = setupHighlighters(settings, plugins); registerScoreFunctions(plugins); registerRetrieverParsers(plugins); + registerScoreNormalizerParsers(plugins); registerQueryParsers(plugins); registerRescorers(plugins); registerRankers(); @@ -1089,6 +1094,11 @@ private void registerRetrieverParsers(List plugins) { registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } + private void registerScoreNormalizerParsers(List plugins) { + registerScoreNormalizer(new ScoreNormalizerSpec<>(MinMaxScoreNormalizer.NAME, MinMaxScoreNormalizer::fromXContent)); + registerScoreNormalizer(new ScoreNormalizerSpec<>(IdentityScoreNormalizer.NAME, IdentityScoreNormalizer::fromXContent)); + } + private void registerQueryParsers(List plugins) { registerQuery(new QuerySpec<>(MatchQueryBuilder.NAME, MatchQueryBuilder::new, MatchQueryBuilder::fromXContent)); registerQuery(new QuerySpec<>(MatchPhraseQueryBuilder.NAME, MatchPhraseQueryBuilder::new, MatchPhraseQueryBuilder::fromXContent)); @@ -1265,6 +1275,17 @@ private void registerRetriever(RetrieverSpec spec) { ); } + private void registerScoreNormalizer(ScoreNormalizerSpec spec) { + namedXContents.add( + new NamedXContentRegistry.Entry( + ScoreNormalizer.class, + spec.getName(), + (p, c) -> spec.getParser().fromXContent(p), + spec.getName().getForRestApiVersion() + ) + ); + } + private void registerQuery(QuerySpec spec) { namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, spec.getName().getPreferredName(), spec.getReader())); namedXContents.add( diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java new file mode 100644 index 0000000000000..234ccec90ddff --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.normalizer; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +public class IdentityScoreNormalizer extends ScoreNormalizer { + + public static final IdentityScoreNormalizer INSTANCE = new IdentityScoreNormalizer(); + + public static final String NAME = "none"; + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + if (args.length != 0) { + throw new IllegalArgumentException("[IdentityScoreNormalizer] does not accept any arguments"); + } + return new IdentityScoreNormalizer(); + }); + + @Override + public String getName() { + return NAME; + } + + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + return docs; + } + + public static IdentityScoreNormalizer fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // no-op + return builder; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java new file mode 100644 index 0000000000000..94f694a16a262 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.normalizer; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class MinMaxScoreNormalizer extends ScoreNormalizer { + + public static final String NAME = "minmax"; + + public static final ParseField MIN_FIELD = new ParseField("min"); + public static final ParseField MAX_FIELD = new ParseField("max"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + Float min = (Float) args[0]; + Float max = (Float) args[1]; + return new MinMaxScoreNormalizer(min, max); + }); + + static { + PARSER.declareFloat(optionalConstructorArg(), MIN_FIELD); + PARSER.declareFloat(optionalConstructorArg(), MAX_FIELD); + } + + private Float min; + private Float max; + + public MinMaxScoreNormalizer() {} + + public MinMaxScoreNormalizer(Float min, Float max) { + this.min = min; + this.max = max; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + // create a new array to avoid changing ScoreDocs in place + ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; + if (min == null || max == null) { + float localMin = Float.MAX_VALUE; + float localMax = Float.MIN_VALUE; + for (ScoreDoc rd : docs) { + if (max == null && rd.score > localMax) { + localMax = rd.score; + } + if (min == null && rd.score < localMin) { + localMin = rd.score; + } + } + if (min == null) { + min = localMin; + } + if (max == null) { + max = localMax; + } + } + float epsilon = Float.MIN_NORMAL; + for (int i = 0; i < docs.length; i++) { + float score = epsilon + ((docs[i].score - min) / (max - min)); + scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); + } + return scoreDocs; + } + + public static MinMaxScoreNormalizer fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MIN_FIELD.getPreferredName(), min); + builder.field(MAX_FIELD.getPreferredName(), max); + return builder; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java new file mode 100644 index 0000000000000..643b896050eb7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.normalizer; + +import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.xcontent.ToXContent; + +/** + * A no-op {@link ScoreNormalizer} that does not modify the scores. + */ +public abstract class ScoreNormalizer implements ToXContent { + + public static ScoreNormalizer valueOf(String normalizer) { + if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { + return new MinMaxScoreNormalizer(); + } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { + return new IdentityScoreNormalizer(); + + } else { + throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]"); + } + } + + public abstract String getName(); + + public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); +} diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java new file mode 100644 index 0000000000000..f439fef5f5140 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.normalizer; + +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +/** + * Defines a ScoreNormalizer parser that is able to parse {@link ScoreNormalizer}s + * from {@link org.elasticsearch.xcontent.XContent}. + */ +@FunctionalInterface +public interface ScoreNormalizerParser { + + /** + * Creates a new {@link RetrieverBuilder} from the retriever held by the + * {@link XContentParser}. The state on the parser contained in this context + * will be changed as a side effect of this method call. The + * {@link RetrieverParserContext} tracks usage of retriever features and + * queries when available. + */ + SN fromXContent(XContentParser parser) throws IOException; +} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 0d71c29052690..da58dd953c1ef 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; @@ -44,11 +45,10 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder PARSER = new ConstructingObjectParser<>( @@ -59,8 +59,7 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder innerRetrievers = new ArrayList<>(); float[] weights = new float[retrieverComponents.size()]; - LinearRetrieverComponent.ScoreNormalizer[] normalizers = new LinearRetrieverComponent.ScoreNormalizer[retrieverComponents - .size()]; + ScoreNormalizer[] normalizers = new ScoreNormalizer[retrieverComponents.size()]; int index = 0; for (LinearRetrieverComponent component : retrieverComponents) { innerRetrievers.add(new RetrieverSource(component.retriever, null)); @@ -94,7 +93,7 @@ public LinearRetrieverBuilder( List innerRetrievers, int rankWindowSize, float[] weights, - LinearRetrieverComponent.ScoreNormalizer[] normalizers + ScoreNormalizer[] normalizers ) { super(innerRetrievers, rankWindowSize); this.weights = weights; @@ -126,7 +125,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { 0f, originalScoreDocs[finalScoreIndex].shardIndex, weights, - Arrays.stream(normalizers).map(Enum::name).toArray(String[]::new) + Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new) ); } value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; @@ -162,7 +161,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.startObject(LinearRetrieverComponent.NAME); builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); - builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].name()); + builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index]); builder.endObject(); builder.endObject(); index++; diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java index 5b39f85c51821..9c3b0b32cf9af 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java @@ -9,8 +9,10 @@ package org.elasticsearch.search.retriever; -import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.search.normalizer.IdentityScoreNormalizer; +import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -29,7 +31,7 @@ public class LinearRetrieverComponent implements ToXContentObject { public static final ParseField NORMALIZER_FIELD = new ParseField("normalizer"); static final float DEFAULT_WEIGHT = 1f; - static final ScoreNormalizer DEFAULT_NORMALIZER = ScoreNormalizer.IDENTITY; + static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE; public static final String NAME = "component"; @@ -39,7 +41,7 @@ public class LinearRetrieverComponent implements ToXContentObject { args -> { RetrieverBuilder base = (RetrieverBuilder) args[0]; float weight = args[1] == null ? DEFAULT_WEIGHT : (float) args[1]; - ScoreNormalizer normalizer = args[2] == null ? DEFAULT_NORMALIZER : ScoreNormalizer.fromString((String) args[2]); + ScoreNormalizer normalizer = args[2] == null ? DEFAULT_NORMALIZER : (ScoreNormalizer) args[2]; return new LinearRetrieverComponent(base, weight, normalizer); } ); @@ -51,7 +53,19 @@ public class LinearRetrieverComponent implements ToXContentObject { return retrieverBuilder; }, RETRIEVER_FIELD); PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); - PARSER.declareString(optionalConstructorArg(), NORMALIZER_FIELD); + + PARSER.declareField(optionalConstructorArg(), (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + final String normalizer = p.text(); + return ScoreNormalizer.valueOf(normalizer); + } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { + p.nextToken(); + ScoreNormalizer normalizer = p.namedObject(ScoreNormalizer.class, p.currentName(), c); + p.nextToken(); + return normalizer; + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, NORMALIZER_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); } RetrieverBuilder retriever; @@ -74,48 +88,7 @@ public static LinearRetrieverComponent fromXContent(XContentParser parser, Retri return PARSER.apply(parser, context); } - public enum ScoreNormalizer { - IDENTITY { - @Override - public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { - // no-op - return docs; - } - }, - MINMAX { - @Override - public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { - // create a new array to avoid changing ScoreDocs in place - ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; - // to avoid 0 scores - float epsilon = Float.MIN_NORMAL; - float min = Float.MAX_VALUE; - float max = Float.MIN_VALUE; - for (ScoreDoc rd : docs) { - if (rd.score > max) { - max = rd.score; - } - if (rd.score < min) { - min = rd.score; - } - } - for (int i = 0; i < docs.length; i++) { - float score = epsilon + ((docs[i].score - min) / (max - min)); - scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); - } - return scoreDocs; - } - }; - - public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); - - public static ScoreNormalizer fromString(final String normalizerName) { - for (ScoreNormalizer normalizer : values()) { - if (normalizer.name().equalsIgnoreCase(normalizerName)) { - return normalizer; - } - } - throw new IllegalArgumentException("Unknown [" + normalizerName + "] ScoreNormalizer provided."); - } + private void setNormalizerField(ScoreNormalizer normalizer) { + this.normalizer = normalizer; } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java index 222072dda1d57..1d6d382726a07 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java @@ -11,6 +11,8 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.normalizer.MinMaxScoreNormalizer; +import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -43,13 +45,13 @@ protected LinearRetrieverBuilder createTestInstance() { int num = randomIntBetween(1, 3); List innerRetrievers = new ArrayList<>(); float[] weights = new float[num]; - LinearRetrieverComponent.ScoreNormalizer[] normalizers = new LinearRetrieverComponent.ScoreNormalizer[num]; + ScoreNormalizer[] normalizers = new ScoreNormalizer[num]; for (int i = 0; i < num; i++) { innerRetrievers.add( new CompoundRetrieverBuilder.RetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder(), null) ); weights[i] = randomFloat(); - normalizers[i] = randomFrom(LinearRetrieverComponent.ScoreNormalizer.values()); + normalizers[i] = randomScoreNormalizer(); } return new LinearRetrieverBuilder(innerRetrievers, rankWindowSize, weights, normalizers); } @@ -80,4 +82,8 @@ protected NamedXContentRegistry xContentRegistry() { ); return new NamedXContentRegistry(entries); } + + private static ScoreNormalizer randomScoreNormalizer() { + return new MinMaxScoreNormalizer(1f, 10f); + } } From 0c1b235b92309fc3ce7f30563b892558d183a58f Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 14 Jan 2025 22:28:56 +0200 Subject: [PATCH 07/57] iter --- .../org/elasticsearch/TransportVersions.java | 1 + .../elasticsearch/search/SearchModule.java | 2 + .../search/rank/LinearRankDoc.java | 23 ++++- .../search/rank/LinearRankDocTests.java | 93 +++++++++++++++++++ 4 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index f0f3d27c6e86c..f63539d074d9d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -156,6 +156,7 @@ static TransportVersion def(int id) { public static final TransportVersion REPLACE_FAILURE_STORE_OPTIONS_WITH_SELECTOR_SYNTAX = def(8_821_00_0); public static final TransportVersion ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION = def(8_822_00_0); public static final TransportVersion KQL_QUERY_TECH_PREVIEW = def(8_823_00_0); + public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_824_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 38867205041df..ff5b062e1cd5b 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -228,6 +228,7 @@ import org.elasticsearch.search.normalizer.IdentityScoreNormalizer; import org.elasticsearch.search.normalizer.MinMaxScoreNormalizer; import org.elasticsearch.search.normalizer.ScoreNormalizer; +import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.search.rank.feature.RankFeatureDoc; @@ -840,6 +841,7 @@ private void registerRescorer(RescorerSpec spec) { private void registerRankers() { namedWriteables.add(new NamedWriteableRegistry.Entry(RankDoc.class, RankDoc.NAME, RankDoc::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RankDoc.class, RankFeatureDoc.NAME, RankFeatureDoc::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(RankShardResult.class, RankFeatureShardResult.NAME, RankFeatureShardResult::new) ); diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java index 41bf73597ea02..7b4ca79b1dc48 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java @@ -10,6 +10,8 @@ package org.elasticsearch.search.rank; import org.apache.lucene.search.Explanation; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.XContentBuilder; @@ -20,8 +22,10 @@ public class LinearRankDoc extends RankDoc { - private final float[] weights; - private final String[] normalizers; + public static final String NAME = "linear_rank_doc"; + + final float[] weights; + final String[] normalizers; public float[] normalizedScores; public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) { @@ -29,10 +33,11 @@ public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, Stri this.weights = weights; this.normalizers = normalizers; this.normalizedScores = new float[normalizers.length]; + Arrays.fill(normalizedScores, 0f); } public LinearRankDoc(StreamInput in) throws IOException { - super(in.readVInt(), in.readFloat(), in.readVInt()); + super(in); weights = in.readFloatArray(); normalizedScores = in.readFloatArray(); normalizers = in.readStringArray(); @@ -104,7 +109,17 @@ public boolean doEquals(RankDoc rd) { @Override public int doHashCode() { - int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores) + Arrays.hashCode(normalizers)); + int result = Objects.hash(Arrays.hashCode(weights), Arrays.hashCode(normalizedScores), Arrays.hashCode(normalizers)); return 31 * result; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.LINEAR_RETRIEVER_SUPPORT; + } } diff --git a/server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java b/server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java new file mode 100644 index 0000000000000..cbc3ef551139d --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.rank; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase { + + protected LinearRankDoc createTestRankDoc() { + int queries = randomIntBetween(2, 20); + float[] weights = new float[queries]; + String[] normalizers = new String[queries]; + float[] normalizedScores = new float[queries]; + for (int i = 0; i < queries; i++) { + weights[i] = randomFloat(); + normalizers[i] = randomAlphaOfLengthBetween(1, 10); + normalizedScores[i] = randomFloat(); + } + LinearRankDoc rankDoc = new LinearRankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1), weights, normalizers); + rankDoc.rank = randomNonNegativeInt(); + rankDoc.normalizedScores = normalizedScores; + return rankDoc; + } + + @Override + protected List getAdditionalNamedWriteables() { + return Collections.emptyList(); + } + + @Override + protected Writeable.Reader instanceReader() { + return LinearRankDoc::new; + } + + @Override + protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOException { + LinearRankDoc mutated = new LinearRankDoc( + instance.doc, + instance.score, + instance.shardIndex, + instance.weights, + instance.normalizers + ); + mutated.normalizedScores = instance.normalizedScores; + mutated.rank = instance.rank; + if (frequently()) { + mutated.doc = randomNonNegativeInt(); + } + if (frequently()) { + mutated.score = randomFloat(); + } + if (frequently()) { + mutated.shardIndex = randomNonNegativeInt(); + } + if (frequently()) { + mutated.rank = randomNonNegativeInt(); + } + if (frequently()) { + for (int i = 0; i < mutated.normalizedScores.length; i++) { + if (frequently()) { + mutated.normalizedScores[i] = randomFloat(); + } + } + } + if (frequently()) { + for (int i = 0; i < mutated.weights.length; i++) { + if (frequently()) { + mutated.weights[i] = randomFloat(); + } + } + } + if (frequently()) { + for (int i = 0; i < mutated.normalizers.length; i++) { + if (frequently()) { + mutated.normalizers[i] = randomAlphaOfLengthBetween(1, 10); + } + } + } + return mutated; + } +} From c97d27b5e01299f93261574d52d5f8220f4ba269 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 15 Jan 2025 09:01:09 +0200 Subject: [PATCH 08/57] iter --- docs/reference/rest-api/common-parms.asciidoc | 5 +++-- .../search/search-your-data/retrievers-examples.asciidoc | 7 ++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index b9abd60c4f8ad..6ef611c76f8c9 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1379,10 +1379,11 @@ results, which will later be merged based on the specified `weight` and `normali The weight that each score of this retriever's top docs will be multiplied with. Defaults to 1.0. * `normalizer`:: -(Optional, string) +(Optional, String or Object) + Specifies how we will normalize the retriever's scores, before applying the specified `weight`. Available values are: `minmax`, `identity`. Defaults to `identity`. -See also <> using a linear retriever. +See also <> using a linear retriever on how to +independently configure and apply normalizers to retrievers. end::linear-retriever-components[] diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index a8939a7d97660..341100095a670 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -240,7 +240,12 @@ GET /retrievers_example/_search } }, "weight": 1.5, - "normalizer": "minmax" + "normalizer": + "minmax": { + "min": 0.5, + "max": 1.0 + } + } } } ], From 2d784042f4b7687f5f3501d87d9be9ccf6eba653 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 15 Jan 2025 19:31:49 +0200 Subject: [PATCH 09/57] iter --- docs/reference/rest-api/common-parms.asciidoc | 32 +- docs/reference/search/retriever.asciidoc | 3 +- .../search.retrievers/40_linear_retriever.yml | 305 ++++++++++++++++-- .../normalizer/MinMaxScoreNormalizer.java | 64 +++- .../retriever/CompoundRetrieverBuilder.java | 1 + .../retriever/LinearRetrieverBuilder.java | 7 +- .../retriever/LinearRetrieverComponent.java | 5 - .../LinearRetrieverBuilderParsingTests.java | 2 +- 8 files changed, 363 insertions(+), 56 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 6ef611c76f8c9..443dd4aa3ae2e 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1382,8 +1382,36 @@ The weight that each score of this retriever's top docs will be multiplied with. (Optional, String or Object) + Specifies how we will normalize the retriever's scores, before applying the specified `weight`. -Available values are: `minmax`, `identity`. Defaults to `identity`. +We can either provide a string reference to use with the default values or further configure any normalizer +using its specific properties. Available values are: `minmax`, `none`. Defaults to `none`. -See also <> using a linear retriever on how to +** `none` : takes no argument +** `minmax` : +A `MinMaxScoreNormalizer` that normalizes scores based on the following formula ++ +``` +score = (score - min) / (max - min) * (upper_bound - lower_bound) + lower_bound +``` +Available properties are: +*** `min`:: +(Optional, float) ++ +The minimum value of the original scores. Defaults to result set's true min value. + +*** `max`:: +(Optional, float) ++ +The maximum value of the original scores. Defaults to result set's true max value. + +*** `lower_bound`:: +(Optional, float) ++ +The minimum value that the retriever's normalized scores can take. Defaults to 0. +*** `upper_bound`:: +(Optional, float) ++ +The maximum value that the retriever's normalized scores can take. Defaults to 1. + +See also <> using a linear retriever on how to independently configure and apply normalizers to retrievers. end::linear-retriever-components[] diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index d33e245f5a5ab..1132d69360a29 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -268,7 +268,8 @@ This value must be fewer than or equal to `num_candidates`. [[linear-retriever]] ==== Linear Retriever -A retriever that normalizes and linearly combines the scores of other retrievers. +A retriever that normalizes and linearly combines the scores of other retrievers. If the final scores produced after the +weighted combination of all sub-retrievers are negative, they are set to increments of `1e-6` to avoid negative scores. ===== Parameters diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index ab6f1169d9188..15f620955190e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -19,6 +19,8 @@ setup: type: keyword other_keyword: type: keyword + timestamp: + type: date - do: bulk: @@ -26,13 +28,13 @@ setup: index: test body: - '{"index": {"_id": 1 }}' - - '{"vector": [1], "keyword": "one", "other_keyword": "other"}' + - '{"vector": [1], "keyword": "one", "other_keyword": "other", "timestamp": "2021-01-01T00:00:00"}' - '{"index": {"_id": 2 }}' - - '{"vector": [2], "keyword": "two"}' + - '{"vector": [2], "keyword": "two", "timestamp": "2022-01-01T00:00:00"}' - '{"index": {"_id": 3 }}' - - '{"vector": [3], "keyword": "three"}' + - '{"vector": [3], "keyword": "three", "timestamp": "2023-01-01T00:00:00"}' - '{"index": {"_id": 4 }}' - - '{"vector": [4], "keyword": "four", "other_keyword": "other"}' + - '{"vector": [4], "keyword": "four", "other_keyword": "other", "timestamp": "2024-01-01T00:00:00"}' --- "basic linear weighted combination of a standard and knn retrievers": @@ -143,9 +145,7 @@ setup: } }, weight: 10.0, - normalizer: { - minmax: { } - } + normalizer: "minmax" } }, { @@ -165,14 +165,119 @@ setup: - match: { hits.total.value: 4 } - match: { hits.hits.0._id: "1" } - - close_to: { hits.hits.0._score: { value: 10.0, error: 0.001 } } + - match: {hits.hits.0._score: 10.0} - match: { hits.hits.1._id: "2" } - - close_to: { hits.hits.1._score: { value: 8.0, error: 0.001 } } + - match: {hits.hits.1._score: 8.0} - match: { hits.hits.2._id: "4" } + - match: {hits.hits.2._score: 2.0} - match: { hits.hits.2._score: 2.0 } - match: { hits.hits.3._id: "3" } - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } +--- +"should normalize initial scores with a custom minmax normalizer": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 # normalized score for this would be -0.55 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 # normalized score for this would be -0.56 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 # normalized score for this would be -0.63 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 1.0 # normalized score for this would be -0.7 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: { + minmax: { + min: 42, + max: 100 + } + } + } + }, + { + # this only provides a score of 10 for doc 4 + component: { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 10.0 + } + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 2.93103, error: 0.001 } } + - match: { hits.hits.1._id: "1" } + - close_to: { hits.hits.1._score: { value: 0.000003, error: 0.001 } } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 0.000002, error: 0.001 } } + - match: { hits.hits.3._id: "3" } + - close_to: { hits.hits.3._score: { value: 0.000001, error: 0.001 } } + --- "should throw on unknown normalizer": - do: @@ -280,9 +385,7 @@ setup: } }, weight: 10.0, - normalizer: { - minmax: { } - } + normalizer: "minmax" } }, { @@ -303,6 +406,7 @@ setup: size: 1 - match: { hits.total.value: 4 } + - length: { hits.hits: 1 } - match: { hits.hits.0._id: "4" } - match: { hits.hits.0._score: 2.0 } @@ -362,9 +466,7 @@ setup: } }, weight: 10.0, - normalizer: { - minmax: { } - } + normalizer: "minmax" } }, { @@ -408,9 +510,7 @@ setup: } }, weight: 10.0, - normalizer: { - minmax: { } - } + normalizer: "minmax" } }, { @@ -487,9 +587,7 @@ setup: } }, weight: 1.0, - normalizer: { - minmax: { } - } + normalizer: "minmax" } }, { @@ -511,9 +609,9 @@ setup: - match: { hits.total.value: 4 } - match: { hits.hits.0._id: "4" } - - close_to: { hits.hits.0._score: { value: 2.0, error: 0.001 } } + - match: { hits.hits.0._score: 2.0 } - match: { hits.hits.1._id: "1" } - - close_to: { hits.hits.1._score: { value: 1.0, error: 0.001 } } + - match: { hits.hits.1._score: 1.0 } --- "explain should provide info on weights and inner retrievers": @@ -562,9 +660,7 @@ setup: } }, weight: 10.0, - normalizer: { - minmax: { } - } + normalizer: "minmax" } }, { @@ -586,15 +682,15 @@ setup: - match: { hits.hits.0._id: "4" } - match: { hits.hits.0._explanation.description: "/weighted.linear.combination.score:.\\[20.0].computed.for.normalized.scores.\\[.*,.1.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} - - close_to: { hits.hits.0._explanation.details.0.value: { value: 0.0, error: 0.001 } } - - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[minmax\\].*/" } - - close_to: { hits.hits.0._explanation.details.1.value: { value: 20.0, error: 0.001 } } + - match: { hits.hits.0._explanation.details.0.value: 0.0 } + - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[0\\].\\[my_standard_retriever\\]/" } + - match: { hits.hits.0._explanation.details.1.value: 20.0 } - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[none\\].*/" } - match: { hits.hits.1._id: "1" } - match: { hits.hits.1._explanation.description: "/weighted.linear.combination.score:.\\[10.0].computed.for.normalized.scores.\\[1.0,.0.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} - - close_to: { hits.hits.1._explanation.details.0.value: { value: 10.0, error: 0.001 } } - - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*using.score.normalizer.\\[minmax\\].*/" } - - close_to: { hits.hits.1._explanation.details.1.value: { value: 0.0, error: 0.001 } } + - match: { hits.hits.1._explanation.details.0.value: 10.0 } + - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[minmax\\].*/" } + - match: { hits.hits.1._explanation.details.1.value: 0.0 } - match: { hits.hits.1._explanation.details.1.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[1\\]/" } --- @@ -849,3 +945,148 @@ setup: - length: {hits.hits: 1} - match: { hits.hits.0._id: "4" } - match: { hits.hits.0._score: 2.0 } + + +--- +"linear retriever with custom sort and score for nested retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + component: { + retriever: { + standard: { + query: { + constant_score: { + filter: { + bool: { + should: [ + { + term: { + keyword: { + value: "one" # this will give doc 1 a normalized score of 0.5 + } + } + }, + { + term: { + keyword: { + value: "two" # this will give doc 2 a normalized score of 0.5 + } + } + } ] + } + }, + boost: 10.0 + } + }, + sort: { + timestamp: { + order: "asc" + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + } + }, + { + # because we're sorting on timestamp and use a rank window size of 2, we will only get to see + # docs 3 and 2. + # their `scores` (which are the timestamps) are: + # doc 3: 1672531200000 (2023-01-01T00:00:00) + # doc 2: 1640995200000 (2022-01-01T00:00:00) + # and their normalized scores based on the provided conf + # will be: + # normalized(doc3) = 1.59989 + # normalized(doc2) = 1.40010 + component: { + retriever: { + standard: { + query: { + function_score: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 1.0 + } + } + ] + } + }, + functions: [ { + script_score: { + script: { + source: "doc['timestamp'].value.millis" + } + } + } ], + "boost_mode": "replace" + } + }, + sort: { + timestamp: { + order: "desc" + } + } + } + }, + weight: 1.0, + normalizer: { + minmax: { + min: 1577836800000, # 2020-01-01T00:00:00 + max: 1735689600000, # 2025-01-01T00:00:00 + lower_bound: 1, + upper_bound: 2 + } + } + } + } + ] + rank_window_size: 2 + size: 2 + + - match: { hits.total.value: 3 } + - length: {hits.hits: 2} + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 1.9, error: 0.001 } } + - match: { hits.hits.1._id: "3" } + - close_to: { hits.hits.1._score: { value: 1.599, error: 0.001 } } diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index 94f694a16a262..4fda2bc34599a 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -25,26 +25,53 @@ public class MinMaxScoreNormalizer extends ScoreNormalizer { public static final ParseField MIN_FIELD = new ParseField("min"); public static final ParseField MAX_FIELD = new ParseField("max"); + public static final ParseField LOWER_BOUND_FIELD = new ParseField("lower_bound"); + public static final ParseField UPPER_BOUND_FIELD = new ParseField("upper_bound"); + + private static final float DEFAULT_LOWER_BOUND = 0f; + private static final float DEFAULT_UPPER_BOUND = 1f; public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { Float min = (Float) args[0]; Float max = (Float) args[1]; - return new MinMaxScoreNormalizer(min, max); + float lowerBound = args[2] == null ? DEFAULT_LOWER_BOUND : (float) args[2]; + float upperBound = args[3] == null ? DEFAULT_UPPER_BOUND : (float) args[3]; + return new MinMaxScoreNormalizer(min, max, lowerBound, upperBound); }); static { PARSER.declareFloat(optionalConstructorArg(), MIN_FIELD); PARSER.declareFloat(optionalConstructorArg(), MAX_FIELD); + PARSER.declareFloat(optionalConstructorArg(), LOWER_BOUND_FIELD); + PARSER.declareFloat(optionalConstructorArg(), UPPER_BOUND_FIELD); } private Float min; private Float max; + private final float lowerBound; + private final float upperBound; + + public MinMaxScoreNormalizer() { + this.min = null; + this.max = null; + this.lowerBound = DEFAULT_LOWER_BOUND; + this.upperBound = DEFAULT_UPPER_BOUND; + } - public MinMaxScoreNormalizer() {} - - public MinMaxScoreNormalizer(Float min, Float max) { + public MinMaxScoreNormalizer(Float min, Float max, float lowerBound, float upperBound) { + if (min != null && max != null && min >= max) { + throw new IllegalArgumentException("[min] must be less than [max]"); + } + if (lowerBound >= upperBound) { + throw new IllegalArgumentException("[lowerBound] must be less than [upperBound]"); + } + if (lowerBound < 0) { + throw new IllegalArgumentException("[lowerBound] must be greater than or equal to 0"); + } this.min = min; this.max = max; + this.lowerBound = lowerBound; + this.upperBound = upperBound; } @Override @@ -57,26 +84,33 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { // create a new array to avoid changing ScoreDocs in place ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; if (min == null || max == null) { - float localMin = Float.MAX_VALUE; - float localMax = Float.MIN_VALUE; + float xMin = Float.MAX_VALUE; + float xMax = Float.MIN_VALUE; for (ScoreDoc rd : docs) { - if (max == null && rd.score > localMax) { - localMax = rd.score; + if (rd.score > xMax) { + xMax = rd.score; } - if (min == null && rd.score < localMin) { - localMin = rd.score; + if (rd.score < xMin) { + xMin = rd.score; } } if (min == null) { - min = localMin; + min = xMin; } if (max == null) { - max = localMax; + max = xMax; } } - float epsilon = Float.MIN_NORMAL; + if (min > max) { + throw new IllegalArgumentException("[min=" + min + "] must be less than [max=" + max + "]"); + } for (int i = 0; i < docs.length; i++) { - float score = epsilon + ((docs[i].score - min) / (max - min)); + float score; + if (min.equals(max)) { + score = (upperBound + lowerBound) / 2; + } else { + score = ((docs[i].score - min) / (max - min) * (upperBound - lowerBound)) + lowerBound; + } scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); } return scoreDocs; @@ -90,6 +124,8 @@ public static MinMaxScoreNormalizer fromXContent(XContentParser parser) { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(MIN_FIELD.getPreferredName(), min); builder.field(MAX_FIELD.getPreferredName(), max); + builder.field(LOWER_BOUND_FIELD.getPreferredName(), lowerBound); + builder.field(UPPER_BOUND_FIELD.getPreferredName(), upperBound); return builder; } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 576e3c2d240f8..536f5d4e884a7 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -278,6 +278,7 @@ public int doHashCode() { protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) .trackTotalHits(false) + .trackScores(true) .storedFields(new StoredFieldsContext(false)) .size(rankWindowSize); // apply the pre-filters downstream once diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index da58dd953c1ef..f2cc20f340447 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -47,6 +47,8 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder rankResults) { for (int rank = 0; rank < topResults.length; ++rank) { topResults[rank] = sortedResults[rank]; topResults[rank].rank = rank + 1; + topResults[rank].score = Math.max(EPSILON * (topResults.length - rank), topResults[rank].score); } return topResults; } @@ -159,7 +162,9 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept for (var entry : innerRetrievers) { builder.startObject(); builder.startObject(LinearRetrieverComponent.NAME); - builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); + builder.startObject(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName()); + entry.retriever().toXContent(builder, params); + builder.endObject(); builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index]); builder.endObject(); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java index 9c3b0b32cf9af..ce0f4bb61cc45 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java @@ -53,7 +53,6 @@ public class LinearRetrieverComponent implements ToXContentObject { return retrieverBuilder; }, RETRIEVER_FIELD); PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); - PARSER.declareField(optionalConstructorArg(), (p, c) -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { final String normalizer = p.text(); @@ -87,8 +86,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { return PARSER.apply(parser, context); } - - private void setNormalizerField(ScoreNormalizer normalizer) { - this.normalizer = normalizer; - } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java index 1d6d382726a07..c0d4f84694afb 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java @@ -84,6 +84,6 @@ protected NamedXContentRegistry xContentRegistry() { } private static ScoreNormalizer randomScoreNormalizer() { - return new MinMaxScoreNormalizer(1f, 10f); + return new MinMaxScoreNormalizer(1f, 10f, 1f, 1f); } } From 06d727a3b5510f74a330ec0b0f4c708c70d7a30f Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 15 Jan 2025 17:42:05 +0000 Subject: [PATCH 10/57] [CI] Auto commit changes from spotless --- .../elasticsearch/search/retriever/RetrieversFeatures.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java index 223d8d66fe3c0..c94d845938db7 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java @@ -22,8 +22,6 @@ public class RetrieversFeatures implements FeatureSpecification { @Override public Set getFeatures() { - return Set.of( - LinearRetrieverBuilder.LINEAR_RETRIEVER_SUPPORTED - ); + return Set.of(LinearRetrieverBuilder.LINEAR_RETRIEVER_SUPPORTED); } } From 822ff1d6dfce938d0b5b2fc82e59930db0312c97 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 08:51:22 +0200 Subject: [PATCH 11/57] iter --- .../search/search-your-data/retrievers-examples.asciidoc | 2 +- .../search/retriever/CompoundRetrieverBuilder.java | 3 +-- .../search/retriever/RankDocsRetrieverBuilder.java | 3 +++ .../search/retriever/RetrieversFeatures.java | 4 +--- .../test/entsearch/rules/80_query_rules_retriever.yml | 8 +++----- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 341100095a670..8c79637a5be3e 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -195,7 +195,7 @@ to ensure that we capture the importance of each document within the result set. So, let's now specify the `linear` retriever whose final score is computed as follows: -[source] +[source, text] ---- score = weight(standard) * score(standard) + weight(knn) * score(knn) score = 2 * score(standard) + 1.5 * score(knn) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 536f5d4e884a7..ffd3bfc796f80 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -41,7 +41,6 @@ import java.util.Objects; import static org.elasticsearch.action.ValidateActions.addValidationError; -import static org.elasticsearch.search.SearchService.DEFAULT_SIZE; /** * This abstract retriever defines a compound retriever. The idea is that it is not a leaf-retriever, i.e. it does not @@ -220,7 +219,7 @@ public ActionRequestValidationException validate( boolean allowPartialSearchResults ) { validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); - final int size = source.size() < 0 ? DEFAULT_SIZE : source.size(); + final int size = source.size(); if (size > rankWindowSize) { validationException = addValidationError( String.format( diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index 4d3f3fefd4462..f873da8c71506 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -125,6 +125,9 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder } // ignore prefilters of this level, they were already propagated to children searchSourceBuilder.query(rankQuery); + if (searchSourceBuilder.size() < 0) { + searchSourceBuilder.size(rankWindowSize); + } if (sourceHasMinScore()) { searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore()); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java index 223d8d66fe3c0..c94d845938db7 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java @@ -22,8 +22,6 @@ public class RetrieversFeatures implements FeatureSpecification { @Override public Set getFeatures() { - return Set.of( - LinearRetrieverBuilder.LINEAR_RETRIEVER_SUPPORTED - ); + return Set.of(LinearRetrieverBuilder.LINEAR_RETRIEVER_SUPPORTED); } } diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml index 089a078c62207..4ce0c55511cbd 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/rules/80_query_rules_retriever.yml @@ -288,10 +288,9 @@ setup: rank_window_size: 1 - match: { hits.total.value: 3 } + - length: { hits.hits: 1 } - match: { hits.hits.0._id: foo } - match: { hits.hits.0._score: 1.7014124E38 } - - match: { hits.hits.1._score: 0 } - - match: { hits.hits.2._score: 0 } - do: headers: @@ -315,12 +314,10 @@ setup: rank_window_size: 2 - match: { hits.total.value: 3 } + - length: { hits.hits: 2 } - match: { hits.hits.0._id: foo } - match: { hits.hits.0._score: 1.7014124E38 } - match: { hits.hits.1._id: foo2 } - - match: { hits.hits.1._score: 1.7014122E38 } - - match: { hits.hits.2._id: bar_no_rule } - - match: { hits.hits.2._score: 0 } - do: headers: @@ -344,6 +341,7 @@ setup: rank_window_size: 10 - match: { hits.total.value: 3 } + - length: { hits.hits: 3 } - match: { hits.hits.0._id: foo } - match: { hits.hits.0._score: 1.7014124E38 } - match: { hits.hits.1._id: foo2 } From 020cd789096d30c45759144baf5fa18fb56cf234 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 09:11:34 +0200 Subject: [PATCH 12/57] Update docs/changelog/120222.yaml --- docs/changelog/120222.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/120222.yaml diff --git a/docs/changelog/120222.yaml b/docs/changelog/120222.yaml new file mode 100644 index 0000000000000..e8e766565e5bc --- /dev/null +++ b/docs/changelog/120222.yaml @@ -0,0 +1,5 @@ +pr: 120222 +summary: Adding linear retriever to support weighted sums of sub-retrievers +area: "Ranking, Search" +type: enhancement +issues: [] From 8d0583a0b4fe3bf65f6eb76c24142a4dc7e92762 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 09:21:48 +0200 Subject: [PATCH 13/57] iter --- docs/changelog/120222.yaml | 2 +- .../search/search-your-data/retrievers-examples.asciidoc | 2 +- .../test/search.retrievers/40_linear_retriever.yml | 4 ++-- .../retriever/LinearRetrieverBuilderParsingTests.java | 6 +++++- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/changelog/120222.yaml b/docs/changelog/120222.yaml index e8e766565e5bc..c9ded878ac031 100644 --- a/docs/changelog/120222.yaml +++ b/docs/changelog/120222.yaml @@ -1,5 +1,5 @@ pr: 120222 summary: Adding linear retriever to support weighted sums of sub-retrievers -area: "Ranking, Search" +area: "Search" type: enhancement issues: [] diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 8c79637a5be3e..3194976a647d1 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -240,7 +240,7 @@ GET /retrievers_example/_search } }, "weight": 1.5, - "normalizer": + "normalizer": { "minmax": { "min": 0.5, "max": 1.0 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index 15f620955190e..ca5ae6c465b54 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -491,7 +491,7 @@ setup: - close_to: { hits.hits.0._score: { value: 0.0, error: 0.001 } } --- -"should throw when rank_window_size less than default size": +"should throw when rank_window_size less than size": - do: catch: "/\\[linear\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/" search: @@ -528,7 +528,7 @@ setup: } ] rank_window_size: 2 - + size: 10 --- "should respect rank_window_size for normalization and returned hits": - do: diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java index c0d4f84694afb..f62586c238b5d 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java @@ -84,6 +84,10 @@ protected NamedXContentRegistry xContentRegistry() { } private static ScoreNormalizer randomScoreNormalizer() { - return new MinMaxScoreNormalizer(1f, 10f, 1f, 1f); + Float min = frequently() ? randomFloat() : null; + Float max = frequently() && min != null ? min + randomFloat() : null; + float lowerBound = random().nextBoolean() ? randomFloat() : 0; + float upperBound = lowerBound + randomFloat(); + return new MinMaxScoreNormalizer(min, max, lowerBound, upperBound); } } From 8ec4110ad83cade99d354a8c0def8e08930f5f78 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 11:00:22 +0200 Subject: [PATCH 14/57] iter --- .../retrievers-examples.asciidoc | 156 +++++++++++++++++- .../normalizer/MinMaxScoreNormalizer.java | 3 + 2 files changed, 153 insertions(+), 6 deletions(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 3194976a647d1..ed35ed98457e6 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -29,6 +29,9 @@ PUT retrievers_example }, "topic": { "type": "keyword" + }, + "timestamp": { + "type": "date" } } } @@ -39,7 +42,8 @@ POST /retrievers_example/_doc/1 "vector": [0.23, 0.67, 0.89], "text": "Large language models are revolutionizing information retrieval by boosting search precision, deepening contextual understanding, and reshaping user experiences in data-rich environments.", "year": 2024, - "topic": ["llm", "ai", "information_retrieval"] + "topic": ["llm", "ai", "information_retrieval"], + "timestamp": "2021-01-01T12:10:30" } POST /retrievers_example/_doc/2 @@ -47,7 +51,8 @@ POST /retrievers_example/_doc/2 "vector": [0.12, 0.56, 0.78], "text": "Artificial intelligence is transforming medicine, from advancing diagnostics and tailoring treatment plans to empowering predictive patient care for improved health outcomes.", "year": 2023, - "topic": ["ai", "medicine"] + "topic": ["ai", "medicine"], + "timestamp": "2022-01-01T12:10:30" } POST /retrievers_example/_doc/3 @@ -55,7 +60,8 @@ POST /retrievers_example/_doc/3 "vector": [0.45, 0.32, 0.91], "text": "AI is redefining security by enabling advanced threat detection, proactive risk analysis, and dynamic defenses against increasingly sophisticated cyber threats.", "year": 2024, - "topic": ["ai", "security"] + "topic": ["ai", "security"], + "timestamp": "2023-01-01T12:10:30" } POST /retrievers_example/_doc/4 @@ -63,7 +69,8 @@ POST /retrievers_example/_doc/4 "vector": [0.34, 0.21, 0.98], "text": "Elastic introduces Elastic AI Assistant, the open, generative AI sidekick powered by ESRE to democratize cybersecurity and enable users of every skill level.", "year": 2023, - "topic": ["ai", "elastic", "assistant"] + "topic": ["ai", "elastic", "assistant"], + "timestamp": "2024-01-01T12:10:30" } POST /retrievers_example/_doc/5 @@ -71,7 +78,8 @@ POST /retrievers_example/_doc/5 "vector": [0.11, 0.65, 0.47], "text": "Learn how to spin up a deployment of our hosted Elasticsearch Service and use Elastic Observability to gain deeper insight into the behavior of your applications and systems.", "year": 2024, - "topic": ["documentation", "observability", "elastic"] + "topic": ["documentation", "observability", "elastic"], + "timestamp": "2025-01-01T12:10:30" } POST /retrievers_example/_refresh @@ -257,7 +265,7 @@ GET /retrievers_example/_search ---- // TEST -This returns the following response based on the final linearly weighted score for each result. +This returns the following response based on the normalized weighted score for each result. .Example response [%collapsible] @@ -306,6 +314,142 @@ This returns the following response based on the final linearly weighted score f // TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] ============== +By normalizing score and leveraging function scores, we can also implement more complex ranking strategies, such as +sorting the results based on their timestamps, assign the timestamp as score, and then normalizing this score to [0, 1] +range where 1 is `today` and `0` is the oldest reference document in the index. +Then, we can easily combine the above with a `knn` retriever for example as follows: + +[source,console] +---- +GET /retrievers_example/_search +{ + "retriever": { + "linear": { + "retrievers": [ + { + "component": { + "retriever": { + "standard": { + "query": { + "function_score": { + "query": { + "term": { + "topic": "ai" + } + }, + "functions": [ + { + "script_score": { + "script": { + "source": "doc['timestamp'].value.millis" + } + } + } + ], + "boost_mode": "replace" + } + }, + "sort": { + "timestamp": { + "order": "asc" + } + } + } + }, + "weight": 2, + "normalizer": { + "minmax": { + "min": "1483228800000" + } + } + } + }, + { + "component": { + "retriever": { + "knn": { + "field": "vector", + "query_vector": [ + 0.23, + 0.67, + 0.89 + ], + "k": 3, + "num_candidates": 5 + } + }, + "weight": 1.5 + } + } + ], + "rank_window_size": 10 + } + }, + "_source": false +} +---- +// TEST + +Which would return the following results: + +.Example response +[%collapsible] +============== +[source,console-result] +---- +{ + "took": 42, + "timed_out": false, + "_shards": { + "total": 1, + "successful": 1, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 5, + "relation": "eq" + }, + "max_score": -1, + "hits": [ + { + "_index": "retrievers_example", + "_id": "2", + "_score": -1 + }, + { + "_index": "retrievers_example", + "_id": "1", + "_score": -2 + }, + { + "_index": "retrievers_example", + "_id": "4", + "_score": -3 + }, + { + "_index": "retrievers_example", + "_id": "3", + "_score": -4 + }, + { + "_index": "retrievers_example", + "_id": "5", + "_score": -5 + } + ] + } +} +---- +// TESTRESPONSE[s/"took": 42/"took": $body.took/] +// TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/] +// TESTRESPONSE[s/"score": -1/"score": $body.hits.hits.0._score/] +// TESTRESPONSE[s/"score": -2/"score": $body.hits.hits.1._score/] +// TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] +// TESTRESPONSE[s/"score": -4/"score": $body.hits.hits.5._score/] +// TESTRESPONSE[s/"score": -5/"score": $body.hits.hits.4._score/] +============== [discrete] [[retrievers-examples-collapsing-retriever-results]] diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index 4fda2bc34599a..a1e91f1a9d9df 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -81,6 +81,9 @@ public String getName() { @Override public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { + if (docs.length == 0) { + return docs; + } // create a new array to avoid changing ScoreDocs in place ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; if (min == null || max == null) { From ceaf3b56552e5eecbb33b00cdb8505c2792584a0 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 11:08:06 +0200 Subject: [PATCH 15/57] iter --- server/src/main/java/module-info.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index c266b5cfecfcd..143a55f65c09f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -55,7 +55,6 @@ requires org.apache.lucene.queryparser; requires org.apache.lucene.sandbox; requires org.apache.lucene.suggest; - requires java.desktop; exports org.elasticsearch; exports org.elasticsearch.action; From ed78bf24406236ab2725bc45edb8049a3c99d956 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 11:30:58 +0200 Subject: [PATCH 16/57] iter --- docs/reference/rest-api/common-parms.asciidoc | 4 ++-- .../search/search-your-data/retrievers-examples.asciidoc | 8 ++++---- .../main/java/org/elasticsearch/plugins/SearchPlugin.java | 6 ++---- .../search/normalizer/ScoreNormalizerParser.java | 8 ++------ 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 443dd4aa3ae2e..4de3c0312e5d6 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1383,14 +1383,14 @@ The weight that each score of this retriever's top docs will be multiplied with. + Specifies how we will normalize the retriever's scores, before applying the specified `weight`. We can either provide a string reference to use with the default values or further configure any normalizer -using its specific properties. Available values are: `minmax`, `none`. Defaults to `none`. +using its specific properties. Available values are: `minmax`, and `none`. Defaults to `none`. ** `none` : takes no argument ** `minmax` : A `MinMaxScoreNormalizer` that normalizes scores based on the following formula + ``` -score = (score - min) / (max - min) * (upper_bound - lower_bound) + lower_bound +score = ((score - min) / (max - min)) * (upper_bound - lower_bound) + lower_bound ``` Available properties are: *** `min`:: diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index ed35ed98457e6..d10b7823bb3c7 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -314,10 +314,10 @@ This returns the following response based on the normalized weighted score for e // TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] ============== -By normalizing score and leveraging function scores, we can also implement more complex ranking strategies, such as -sorting the results based on their timestamps, assign the timestamp as score, and then normalizing this score to [0, 1] -range where 1 is `today` and `0` is the oldest reference document in the index. -Then, we can easily combine the above with a `knn` retriever for example as follows: +By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies, +such as sorting results based on their timestamps, assign the timestamp as a score, and then normalizing this score to +[0, 1] range where 1 is `today` and `0` is the oldest reference document in the index. +Then, we can easily combine the above with a `knn` retriever as follows: [source,console] ---- diff --git a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java index ac706dab459ea..855410f1b9eaa 100644 --- a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java @@ -301,8 +301,7 @@ public RetrieverSpec(ParseField name, RetrieverParser parser) { /** * Specification of custom {@link RetrieverBuilder}. * - * @param name the name by which this retriever might be parsed or deserialized. Make sure that the retriever builder returns - * this name for {@link NamedWriteable#getWriteableName()}. + * @param name the name by which this retriever might be parsed or deserialized. * @param parser the parser the reads the retriever builder from xcontent */ public RetrieverSpec(String name, RetrieverParser parser) { @@ -330,8 +329,7 @@ class ScoreNormalizerSpec { * Specification of custom {@link ScoreNormalizer}. * * @param name holds the names by which this score normalizer might be parsed. The {@link ParseField#getPreferredName()} - * is special as it is the name by under which the reader is registered. So it is the name that the normalizer - * should use as its {@link NamedWriteable#getWriteableName()} too. + * is special as it is the name by under which the reader is registered. * @param parser the parser the reads the retriever builder from xcontent */ public ScoreNormalizerSpec(ParseField name, ScoreNormalizerParser parser) { diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java index f439fef5f5140..6b98064c86b0c 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java @@ -9,8 +9,6 @@ package org.elasticsearch.search.normalizer; -import org.elasticsearch.search.retriever.RetrieverBuilder; -import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; @@ -23,11 +21,9 @@ public interface ScoreNormalizerParser { /** - * Creates a new {@link RetrieverBuilder} from the retriever held by the + * Creates a new {@link ScoreNormalizerParser} from the normalizer held by the * {@link XContentParser}. The state on the parser contained in this context - * will be changed as a side effect of this method call. The - * {@link RetrieverParserContext} tracks usage of retriever features and - * queries when available. + * will be changed as a side effect of this method call. */ SN fromXContent(XContentParser parser) throws IOException; } From a70b0d631f6981479baffc54726a770e09bf071a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 12:05:33 +0200 Subject: [PATCH 17/57] iter --- .../search-your-data/retrievers-examples.asciidoc | 2 +- .../normalizer/IdentityScoreNormalizer.java | 3 +-- .../search/normalizer/MinMaxScoreNormalizer.java | 11 +++++++---- .../search/normalizer/ScoreNormalizer.java | 15 +++++++++++++++ .../search/retriever/LinearRetrieverBuilder.java | 4 +--- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index d10b7823bb3c7..20d82fa5ad62d 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -447,7 +447,7 @@ Which would return the following results: // TESTRESPONSE[s/"score": -1/"score": $body.hits.hits.0._score/] // TESTRESPONSE[s/"score": -2/"score": $body.hits.hits.1._score/] // TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] -// TESTRESPONSE[s/"score": -4/"score": $body.hits.hits.5._score/] +// TESTRESPONSE[s/"score": -4/"score": $body.hits.hits.3._score/] // TESTRESPONSE[s/"score": -5/"score": $body.hits.hits.4._score/] ============== diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java index 234ccec90ddff..9a5c7af0c390f 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java @@ -44,8 +44,7 @@ public static IdentityScoreNormalizer fromXContent(XContentParser parser) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + public void doToXContent(XContentBuilder builder, Params params) throws IOException { // no-op - return builder; } } diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index a1e91f1a9d9df..c4d16263ee94c 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -124,11 +124,14 @@ public static MinMaxScoreNormalizer fromXContent(XContentParser parser) { } @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(MIN_FIELD.getPreferredName(), min); - builder.field(MAX_FIELD.getPreferredName(), max); + public void doToXContent(XContentBuilder builder, Params params) throws IOException { + if (min != null) { + builder.field(MIN_FIELD.getPreferredName(), min); + } + if (max != null) { + builder.field(MAX_FIELD.getPreferredName(), max); + } builder.field(LOWER_BOUND_FIELD.getPreferredName(), lowerBound); builder.field(UPPER_BOUND_FIELD.getPreferredName(), upperBound); - return builder; } } diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java index 643b896050eb7..f21e105acb2c6 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java @@ -11,6 +11,9 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; /** * A no-op {@link ScoreNormalizer} that does not modify the scores. @@ -28,6 +31,18 @@ public static ScoreNormalizer valueOf(String normalizer) { } } + protected abstract void doToXContent(XContentBuilder builder, Params params) throws IOException; + + public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startObject(getName()); + doToXContent(builder, params); + builder.endObject(); + builder.endObject(); + + return builder; + } + public abstract String getName(); public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index f2cc20f340447..d71c93f5b3cae 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -162,9 +162,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept for (var entry : innerRetrievers) { builder.startObject(); builder.startObject(LinearRetrieverComponent.NAME); - builder.startObject(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName()); - entry.retriever().toXContent(builder, params); - builder.endObject(); + builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index]); builder.endObject(); From 9304c7b00c14594f6b6c3865780edb0b4f9bba2a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 12:24:13 +0200 Subject: [PATCH 18/57] iter --- .../retrievers-examples.asciidoc | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 20d82fa5ad62d..df7f58b27e843 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -13,13 +13,20 @@ To begin with, lets create the `retrievers_example` index, and add some document ---- PUT retrievers_example { + "settings": { + "number_of_shards": 1, + "number_of_replicas": 0 + }, "mappings": { "properties": { "vector": { "type": "dense_vector", "dims": 3, "similarity": "l2_norm", - "index": true + "index": true, + "index_options": { + "type": "flat" + } }, "text": { "type": "text" @@ -408,35 +415,30 @@ Which would return the following results: }, "hits": { "total": { - "value": 5, + "value": 4, "relation": "eq" }, "max_score": -1, "hits": [ { "_index": "retrievers_example", - "_id": "2", + "_id": "3", "_score": -1 }, { "_index": "retrievers_example", - "_id": "1", + "_id": "2", "_score": -2 }, { "_index": "retrievers_example", - "_id": "4", + "_id": "1", "_score": -3 }, { "_index": "retrievers_example", - "_id": "3", + "_id": "4", "_score": -4 - }, - { - "_index": "retrievers_example", - "_id": "5", - "_score": -5 } ] } @@ -448,7 +450,6 @@ Which would return the following results: // TESTRESPONSE[s/"score": -2/"score": $body.hits.hits.1._score/] // TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] // TESTRESPONSE[s/"score": -4/"score": $body.hits.hits.3._score/] -// TESTRESPONSE[s/"score": -5/"score": $body.hits.hits.4._score/] ============== [discrete] From 05aae700b1c2c01482bb054a8ef01d6e5962ac64 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 12:36:01 +0200 Subject: [PATCH 19/57] iter --- .../retrievers-examples.asciidoc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index df7f58b27e843..a4053d95eaea5 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -270,7 +270,7 @@ GET /retrievers_example/_search "_source": false } ---- -// TEST +// TEST[continued] This returns the following response based on the normalized weighted score for each result. @@ -307,7 +307,7 @@ This returns the following response based on the normalized weighted score for e }, { "_index": "retrievers_example", - "_id": "5", + "_id": "3", "_score": -3 } ] @@ -316,9 +316,9 @@ This returns the following response based on the normalized weighted score for e ---- // TESTRESPONSE[s/"took": 42/"took": $body.took/] // TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/] -// TESTRESPONSE[s/"score": -1/"score": $body.hits.hits.0._score/] -// TESTRESPONSE[s/"score": -2/"score": $body.hits.hits.1._score/] -// TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] +// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/] +// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/] +// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/] ============== By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies, @@ -395,7 +395,7 @@ GET /retrievers_example/_search "_source": false } ---- -// TEST +// TEST[continued] Which would return the following results: @@ -446,10 +446,10 @@ Which would return the following results: ---- // TESTRESPONSE[s/"took": 42/"took": $body.took/] // TESTRESPONSE[s/"max_score": -1/"max_score": $body.hits.max_score/] -// TESTRESPONSE[s/"score": -1/"score": $body.hits.hits.0._score/] -// TESTRESPONSE[s/"score": -2/"score": $body.hits.hits.1._score/] -// TESTRESPONSE[s/"score": -3/"score": $body.hits.hits.2._score/] -// TESTRESPONSE[s/"score": -4/"score": $body.hits.hits.3._score/] +// TESTRESPONSE[s/"_score": -1/"_score": $body.hits.hits.0._score/] +// TESTRESPONSE[s/"_score": -2/"_score": $body.hits.hits.1._score/] +// TESTRESPONSE[s/"_score": -3/"_score": $body.hits.hits.2._score/] +// TESTRESPONSE[s/"_score": -4/"_score": $body.hits.hits.3._score/] ============== [discrete] From 4fde94788609a2f507601888ec9a7ed9ccaebf15 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 15:49:20 +0200 Subject: [PATCH 20/57] addressing PR comments - removing lower_bound and upper_bound params --- docs/reference/rest-api/common-parms.asciidoc | 10 +----- .../normalizer/MinMaxScoreNormalizer.java | 31 +++---------------- .../LinearRetrieverBuilderParsingTests.java | 4 +-- 3 files changed, 6 insertions(+), 39 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 4de3c0312e5d6..8029543941cca 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1390,7 +1390,7 @@ using its specific properties. Available values are: `minmax`, and `none`. Defau A `MinMaxScoreNormalizer` that normalizes scores based on the following formula + ``` -score = ((score - min) / (max - min)) * (upper_bound - lower_bound) + lower_bound +score = (score - min) / (max - min) ``` Available properties are: *** `min`:: @@ -1403,14 +1403,6 @@ The minimum value of the original scores. Defaults to result set's true min valu + The maximum value of the original scores. Defaults to result set's true max value. -*** `lower_bound`:: -(Optional, float) -+ -The minimum value that the retriever's normalized scores can take. Defaults to 0. -*** `upper_bound`:: -(Optional, float) -+ -The maximum value that the retriever's normalized scores can take. Defaults to 1. See also <> using a linear retriever on how to independently configure and apply normalizers to retrievers. diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index c4d16263ee94c..188f2df0fb54d 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -25,53 +25,32 @@ public class MinMaxScoreNormalizer extends ScoreNormalizer { public static final ParseField MIN_FIELD = new ParseField("min"); public static final ParseField MAX_FIELD = new ParseField("max"); - public static final ParseField LOWER_BOUND_FIELD = new ParseField("lower_bound"); - public static final ParseField UPPER_BOUND_FIELD = new ParseField("upper_bound"); - - private static final float DEFAULT_LOWER_BOUND = 0f; - private static final float DEFAULT_UPPER_BOUND = 1f; public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { Float min = (Float) args[0]; Float max = (Float) args[1]; - float lowerBound = args[2] == null ? DEFAULT_LOWER_BOUND : (float) args[2]; - float upperBound = args[3] == null ? DEFAULT_UPPER_BOUND : (float) args[3]; - return new MinMaxScoreNormalizer(min, max, lowerBound, upperBound); + return new MinMaxScoreNormalizer(min, max); }); static { PARSER.declareFloat(optionalConstructorArg(), MIN_FIELD); PARSER.declareFloat(optionalConstructorArg(), MAX_FIELD); - PARSER.declareFloat(optionalConstructorArg(), LOWER_BOUND_FIELD); - PARSER.declareFloat(optionalConstructorArg(), UPPER_BOUND_FIELD); } private Float min; private Float max; - private final float lowerBound; - private final float upperBound; public MinMaxScoreNormalizer() { this.min = null; this.max = null; - this.lowerBound = DEFAULT_LOWER_BOUND; - this.upperBound = DEFAULT_UPPER_BOUND; } - public MinMaxScoreNormalizer(Float min, Float max, float lowerBound, float upperBound) { + public MinMaxScoreNormalizer(Float min, Float max) { if (min != null && max != null && min >= max) { throw new IllegalArgumentException("[min] must be less than [max]"); } - if (lowerBound >= upperBound) { - throw new IllegalArgumentException("[lowerBound] must be less than [upperBound]"); - } - if (lowerBound < 0) { - throw new IllegalArgumentException("[lowerBound] must be greater than or equal to 0"); - } this.min = min; this.max = max; - this.lowerBound = lowerBound; - this.upperBound = upperBound; } @Override @@ -110,9 +89,9 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { for (int i = 0; i < docs.length; i++) { float score; if (min.equals(max)) { - score = (upperBound + lowerBound) / 2; + score = (max + min) / 2; } else { - score = ((docs[i].score - min) / (max - min) * (upperBound - lowerBound)) + lowerBound; + score = (docs[i].score - min) / (max - min); } scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); } @@ -131,7 +110,5 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept if (max != null) { builder.field(MAX_FIELD.getPreferredName(), max); } - builder.field(LOWER_BOUND_FIELD.getPreferredName(), lowerBound); - builder.field(UPPER_BOUND_FIELD.getPreferredName(), upperBound); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java index f62586c238b5d..bfcc627c40773 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java @@ -86,8 +86,6 @@ protected NamedXContentRegistry xContentRegistry() { private static ScoreNormalizer randomScoreNormalizer() { Float min = frequently() ? randomFloat() : null; Float max = frequently() && min != null ? min + randomFloat() : null; - float lowerBound = random().nextBoolean() ? randomFloat() : 0; - float upperBound = lowerBound + randomFloat(); - return new MinMaxScoreNormalizer(min, max, lowerBound, upperBound); + return new MinMaxScoreNormalizer(min, max); } } From e71b25e2011c0061be20b28a3db4844c40778d12 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 16:19:29 +0200 Subject: [PATCH 21/57] fix test --- .../test/search.retrievers/40_linear_retriever.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index ca5ae6c465b54..d9c6750979021 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -1073,10 +1073,7 @@ setup: normalizer: { minmax: { min: 1577836800000, # 2020-01-01T00:00:00 - max: 1735689600000, # 2025-01-01T00:00:00 - lower_bound: 1, - upper_bound: 2 - } + max: 1735689600000 # 2025-01-01T00:00:00 } } } From 21a78d527337786d6bb2a08d2bc4b73882f12b34 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 17:24:39 +0200 Subject: [PATCH 22/57] addressing PR comments - removing ScoreNormalizerParser --- .../search.retrievers/40_linear_retriever.yml | 17 ++++---- .../elasticsearch/plugins/SearchPlugin.java | 41 ------------------- .../elasticsearch/search/SearchModule.java | 21 ---------- .../normalizer/MinMaxScoreNormalizer.java | 2 +- .../search/normalizer/ScoreNormalizer.java | 10 +++++ .../normalizer/ScoreNormalizerParser.java | 29 ------------- .../retriever/LinearRetrieverComponent.java | 5 ++- 7 files changed, 23 insertions(+), 102 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index d9c6750979021..8794c8bfd7f44 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -281,7 +281,7 @@ setup: --- "should throw on unknown normalizer": - do: - catch: /unknown field \[aardvark\]/ + catch: /Unknown normalizer \[aardvark\]/ search: index: test body: @@ -968,14 +968,14 @@ setup: { term: { keyword: { - value: "one" # this will give doc 1 a normalized score of 0.5 + value: "one" # this will give doc 1 a normalized score of 10 } } }, { term: { keyword: { - value: "two" # this will give doc 2 a normalized score of 0.5 + value: "two" # this will give doc 2 a normalized score of 10 } } } ] @@ -1003,8 +1003,8 @@ setup: # doc 2: 1640995200000 (2022-01-01T00:00:00) # and their normalized scores based on the provided conf # will be: - # normalized(doc3) = 1.59989 - # normalized(doc2) = 1.40010 + # normalized(doc3) = 0.59989 + # normalized(doc2) = 0.40010 component: { retriever: { standard: { @@ -1074,6 +1074,7 @@ setup: minmax: { min: 1577836800000, # 2020-01-01T00:00:00 max: 1735689600000 # 2025-01-01T00:00:00 + } } } } @@ -1084,6 +1085,6 @@ setup: - match: { hits.total.value: 3 } - length: {hits.hits: 2} - match: { hits.hits.0._id: "2" } - - close_to: { hits.hits.0._score: { value: 1.9, error: 0.001 } } - - match: { hits.hits.1._id: "3" } - - close_to: { hits.hits.1._score: { value: 1.599, error: 0.001 } } + - close_to: { hits.hits.0._score: { value: 10.4001, error: 0.001 } } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 10 } diff --git a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java index 855410f1b9eaa..105c64200d637 100644 --- a/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java +++ b/server/src/main/java/org/elasticsearch/plugins/SearchPlugin.java @@ -37,8 +37,6 @@ import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.internal.ShardSearchRequest; -import org.elasticsearch.search.normalizer.ScoreNormalizer; -import org.elasticsearch.search.normalizer.ScoreNormalizerParser; import org.elasticsearch.search.rescore.Rescorer; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; @@ -317,45 +315,6 @@ public RetrieverParser getParser() { } } - /** - * Specification of custom {@link ScoreNormalizer}. - */ - class ScoreNormalizerSpec { - - private final ParseField name; - private final ScoreNormalizerParser parser; - - /** - * Specification of custom {@link ScoreNormalizer}. - * - * @param name holds the names by which this score normalizer might be parsed. The {@link ParseField#getPreferredName()} - * is special as it is the name by under which the reader is registered. - * @param parser the parser the reads the retriever builder from xcontent - */ - public ScoreNormalizerSpec(ParseField name, ScoreNormalizerParser parser) { - this.name = name; - this.parser = parser; - } - - /** - * Specification of custom {@link ScoreNormalizer}. - * - * @param name the name by which this normalizer might be parsed or deserialized - * @param parser the parser the reads the {@code ScoreNormalizer} from xcontent - */ - public ScoreNormalizerSpec(String name, ScoreNormalizerParser parser) { - this(new ParseField(name), parser); - } - - public ParseField getName() { - return name; - } - - public ScoreNormalizerParser getParser() { - return parser; - } - } - /** * Specification of custom {@link Query}. */ diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index ff5b062e1cd5b..00639fe8fed91 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -87,7 +87,6 @@ import org.elasticsearch.plugins.SearchPlugin.RescorerSpec; import org.elasticsearch.plugins.SearchPlugin.RetrieverSpec; import org.elasticsearch.plugins.SearchPlugin.ScoreFunctionSpec; -import org.elasticsearch.plugins.SearchPlugin.ScoreNormalizerSpec; import org.elasticsearch.plugins.SearchPlugin.SearchExtSpec; import org.elasticsearch.plugins.SearchPlugin.SignificanceHeuristicSpec; import org.elasticsearch.plugins.SearchPlugin.SuggesterSpec; @@ -225,9 +224,6 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.fetch.subphase.highlight.PlainHighlighter; import org.elasticsearch.search.internal.ShardSearchRequest; -import org.elasticsearch.search.normalizer.IdentityScoreNormalizer; -import org.elasticsearch.search.normalizer.MinMaxScoreNormalizer; -import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankShardResult; @@ -356,7 +352,6 @@ public SearchModule(Settings settings, List plugins, TelemetryProv highlighters = setupHighlighters(settings, plugins); registerScoreFunctions(plugins); registerRetrieverParsers(plugins); - registerScoreNormalizerParsers(plugins); registerQueryParsers(plugins); registerRescorers(plugins); registerRankers(); @@ -1096,11 +1091,6 @@ private void registerRetrieverParsers(List plugins) { registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } - private void registerScoreNormalizerParsers(List plugins) { - registerScoreNormalizer(new ScoreNormalizerSpec<>(MinMaxScoreNormalizer.NAME, MinMaxScoreNormalizer::fromXContent)); - registerScoreNormalizer(new ScoreNormalizerSpec<>(IdentityScoreNormalizer.NAME, IdentityScoreNormalizer::fromXContent)); - } - private void registerQueryParsers(List plugins) { registerQuery(new QuerySpec<>(MatchQueryBuilder.NAME, MatchQueryBuilder::new, MatchQueryBuilder::fromXContent)); registerQuery(new QuerySpec<>(MatchPhraseQueryBuilder.NAME, MatchPhraseQueryBuilder::new, MatchPhraseQueryBuilder::fromXContent)); @@ -1277,17 +1267,6 @@ private void registerRetriever(RetrieverSpec spec) { ); } - private void registerScoreNormalizer(ScoreNormalizerSpec spec) { - namedXContents.add( - new NamedXContentRegistry.Entry( - ScoreNormalizer.class, - spec.getName(), - (p, c) -> spec.getParser().fromXContent(p), - spec.getName().getForRestApiVersion() - ) - ); - } - private void registerQuery(QuerySpec spec) { namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, spec.getName().getPreferredName(), spec.getReader())); namedXContents.add( diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index 188f2df0fb54d..556ab1200b939 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -89,7 +89,7 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { for (int i = 0; i < docs.length; i++) { float score; if (min.equals(max)) { - score = (max + min) / 2; + score = min; } else { score = (docs[i].score - min) / (max - min); } diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java index f21e105acb2c6..d12f8e4497b08 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; @@ -31,6 +32,15 @@ public static ScoreNormalizer valueOf(String normalizer) { } } + public static ScoreNormalizer parse(String normalizer, XContentParser p) { + if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { + return MinMaxScoreNormalizer.fromXContent(p); + } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { + return IdentityScoreNormalizer.fromXContent(p); + } + throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]"); + } + protected abstract void doToXContent(XContentBuilder builder, Params params) throws IOException; public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java b/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java deleted file mode 100644 index 6b98064c86b0c..0000000000000 --- a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizerParser.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". - */ - -package org.elasticsearch.search.normalizer; - -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; - -/** - * Defines a ScoreNormalizer parser that is able to parse {@link ScoreNormalizer}s - * from {@link org.elasticsearch.xcontent.XContent}. - */ -@FunctionalInterface -public interface ScoreNormalizerParser { - - /** - * Creates a new {@link ScoreNormalizerParser} from the normalizer held by the - * {@link XContentParser}. The state on the parser contained in this context - * will be changed as a side effect of this method call. - */ - SN fromXContent(XContentParser parser) throws IOException; -} diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java index ce0f4bb61cc45..33b706292bb5e 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java @@ -59,9 +59,10 @@ public class LinearRetrieverComponent implements ToXContentObject { return ScoreNormalizer.valueOf(normalizer); } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { p.nextToken(); - ScoreNormalizer normalizer = p.namedObject(ScoreNormalizer.class, p.currentName(), c); + final String normalizerName = p.currentName(); + ScoreNormalizer scoreNormalizer = ScoreNormalizer.parse(normalizerName, p); p.nextToken(); - return normalizer; + return scoreNormalizer; } throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); }, NORMALIZER_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); From 77fd4e189c46a7fb31acd55d8e88a692e5a22201 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 17:31:28 +0200 Subject: [PATCH 23/57] removing export from module info --- server/src/main/java/module-info.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 143a55f65c09f..2a68b65bcdccb 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -355,7 +355,6 @@ exports org.elasticsearch.search.fetch.subphase.highlight; exports org.elasticsearch.search.internal; exports org.elasticsearch.search.lookup; - exports org.elasticsearch.search.normalizer; exports org.elasticsearch.search.profile; exports org.elasticsearch.search.profile.aggregation; exports org.elasticsearch.search.profile.dfs; From a79b2806f9921991cd069bd6ef1be98bfd6075a5 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 17:59:55 +0200 Subject: [PATCH 24/57] iter --- docs/reference/search/retriever.asciidoc | 3 +- .../search.retrievers/40_linear_retriever.yml | 9 +++-- .../normalizer/MinMaxScoreNormalizer.java | 34 +++++++++++-------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index 1132d69360a29..323b25f5e3c6b 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -269,7 +269,8 @@ This value must be fewer than or equal to `num_candidates`. [[linear-retriever]] ==== Linear Retriever A retriever that normalizes and linearly combines the scores of other retrievers. If the final scores produced after the -weighted combination of all sub-retrievers are negative, they are set to increments of `1e-6` to avoid negative scores. +weighted combination of all sub-retrievers are negative, a corrective factor is applied equal to the minimum score, +so all scores are positive. ===== Parameters diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index 8794c8bfd7f44..1245dc1024cea 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -243,7 +243,6 @@ setup: } } }, - weight: 10.0, normalizer: { minmax: { min: 42, @@ -270,13 +269,13 @@ setup: - match: { hits.total.value: 4 } - match: { hits.hits.0._id: "4" } - - close_to: { hits.hits.0._score: { value: 2.93103, error: 0.001 } } + - close_to: { hits.hits.0._score: { value: 50.2931, error: 0.001 } } - match: { hits.hits.1._id: "1" } - - close_to: { hits.hits.1._score: { value: 0.000003, error: 0.001 } } + - close_to: { hits.hits.1._score: { value: 40.4482, error: 0.001 } } - match: { hits.hits.2._id: "2" } - - close_to: { hits.hits.2._score: { value: 0.000002, error: 0.001 } } + - close_to: { hits.hits.2._score: { value: 40.4310, error: 0.001 } } - match: { hits.hits.3._id: "3" } - - close_to: { hits.hits.3._score: { value: 0.000001, error: 0.001 } } + - close_to: { hits.hits.3._score: { value: 40.3620, error: 0.001 } } --- "should throw on unknown normalizer": diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index 556ab1200b939..11de6d77c86e8 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -65,24 +65,28 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { } // create a new array to avoid changing ScoreDocs in place ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; - if (min == null || max == null) { - float xMin = Float.MAX_VALUE; - float xMax = Float.MIN_VALUE; - for (ScoreDoc rd : docs) { - if (rd.score > xMax) { - xMax = rd.score; - } - if (rd.score < xMin) { - xMin = rd.score; - } + float correction = 0f; + float xMin = Float.MAX_VALUE; + float xMax = Float.MIN_VALUE; + for (ScoreDoc rd : docs) { + if (rd.score > xMax) { + xMax = rd.score; } - if (min == null) { - min = xMin; + if (rd.score < xMin) { + xMin = rd.score; } - if (max == null) { - max = xMax; + } + if (min == null) { + min = xMin; + }else { + if (min > xMin) { + correction = min - xMin; } } + if (max == null) { + max = xMax; + } + if (min > max) { throw new IllegalArgumentException("[min=" + min + "] must be less than [max=" + max + "]"); } @@ -91,7 +95,7 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { if (min.equals(max)) { score = min; } else { - score = (docs[i].score - min) / (max - min); + score = correction + (docs[i].score - min) / (max - min); } scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); } From ff0c8c365a66cbd311bd240d7ff810428d542255 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 16 Jan 2025 18:10:49 +0200 Subject: [PATCH 25/57] spotless --- .../elasticsearch/search/normalizer/MinMaxScoreNormalizer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java index 11de6d77c86e8..fa7c323f3b67f 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java @@ -78,7 +78,7 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { } if (min == null) { min = xMin; - }else { + } else { if (min > xMin) { correction = min - xMin; } From 8d53d73e6d0bc9cbea187c67121edec66413f283 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 17 Jan 2025 11:49:28 +0200 Subject: [PATCH 26/57] addressing PR comments - updating linear component parsing --- docs/reference/rest-api/common-parms.asciidoc | 10 +- .../retrievers-examples.asciidoc | 144 +- .../search.retrievers/40_linear_retriever.yml | 1229 ++++++++--------- .../IdentityScoreNormalizer.java | 2 +- .../retriever/LinearRetrieverBuilder.java | 12 +- .../retriever/LinearRetrieverComponent.java | 94 +- .../MinMaxScoreNormalizer.java | 2 +- .../ScoreNormalizer.java | 2 +- .../LinearRetrieverBuilderParsingTests.java | 2 - 9 files changed, 709 insertions(+), 788 deletions(-) rename server/src/main/java/org/elasticsearch/search/{normalizer => retriever}/IdentityScoreNormalizer.java (97%) rename server/src/main/java/org/elasticsearch/search/{normalizer => retriever}/MinMaxScoreNormalizer.java (98%) rename server/src/main/java/org/elasticsearch/search/{normalizer => retriever}/ScoreNormalizer.java (98%) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 8029543941cca..2616fe9d59253 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1358,14 +1358,14 @@ according to each retriever's specifications. end::compound-retriever-filter[] tag::linear-retriever-components[] -`components`:: -(Required, array of `component` objects) +`retrievers`:: +(Required, array of objects) + -A list of the components, i.e. the sub-retrievers' configuration, that we will take into account and whose result sets -we will merge through a weighted sum. Each component can have a different weight and normalization depending +A list of the sub-retrievers' configuration, that we will take into account and whose result sets +we will merge through a weighted sum. Each configuration can have a different weight and normalization depending on the specified retriever. -Each `component` entry specifies the following parameters: +Each entry specifies the following parameters: * `retriever`:: (Required, a <> object) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index a4053d95eaea5..cf1eb330047c6 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -225,41 +225,37 @@ GET /retrievers_example/_search "linear": { "retrievers": [ { - "component": { - "retriever": { - "standard": { - "query": { - "query_string": { - "query": "(information retrieval) OR (artificial intelligence)", - "default_field": "text" - } + "retriever": { + "standard": { + "query": { + "query_string": { + "query": "(information retrieval) OR (artificial intelligence)", + "default_field": "text" } } - }, - "weight": 2, - "normalizer": "minmax" - } + } + }, + "weight": 2, + "normalizer": "minmax" }, { - "component": { - "retriever": { - "knn": { - "field": "vector", - "query_vector": [ - 0.23, - 0.67, - 0.89 - ], - "k": 3, - "num_candidates": 5 - } - }, - "weight": 1.5, - "normalizer": { - "minmax": { - "min": 0.5, - "max": 1.0 - } + "retriever": { + "knn": { + "field": "vector", + "query_vector": [ + 0.23, + 0.67, + 0.89 + ], + "k": 3, + "num_candidates": 5 + } + }, + "weight": 1.5, + "normalizer": { + "minmax": { + "min": 0.5, + "max": 1.0 } } } @@ -334,59 +330,55 @@ GET /retrievers_example/_search "linear": { "retrievers": [ { - "component": { - "retriever": { - "standard": { - "query": { - "function_score": { - "query": { - "term": { - "topic": "ai" - } - }, - "functions": [ - { - "script_score": { - "script": { - "source": "doc['timestamp'].value.millis" - } + "retriever": { + "standard": { + "query": { + "function_score": { + "query": { + "term": { + "topic": "ai" + } + }, + "functions": [ + { + "script_score": { + "script": { + "source": "doc['timestamp'].value.millis" } } - ], - "boost_mode": "replace" - } - }, - "sort": { - "timestamp": { - "order": "asc" - } + } + ], + "boost_mode": "replace" + } + }, + "sort": { + "timestamp": { + "order": "asc" } } - }, - "weight": 2, - "normalizer": { - "minmax": { - "min": "1483228800000" - } + } + }, + "weight": 2, + "normalizer": { + "minmax": { + "min": "1483228800000" } } }, { - "component": { - "retriever": { - "knn": { - "field": "vector", - "query_vector": [ - 0.23, - 0.67, - 0.89 - ], - "k": 3, - "num_candidates": 5 - } - }, - "weight": 1.5 - } + "retriever": { + "knn": { + "field": "vector", + "query_vector": [ + 0.23, + 0.67, + 0.89 + ], + "k": 3, + "num_candidates": 5 + } + }, + "weight": 1.5 } ], "rank_window_size": 10 diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml index 1245dc1024cea..d82d2f4d41c4f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml @@ -46,38 +46,34 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } + } + }, + boost: 10.0 } } - }, - weight: 0.5 - } + } + }, + weight: 0.5 }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] @@ -97,69 +93,65 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" } - }, - boost: 5.0 - } + } + }, + boost: 5.0 } - ] - } + } + ] } } - }, - weight: 10.0, - normalizer: "minmax" - } + } + }, + weight: 10.0, + normalizer: "minmax" }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] @@ -184,86 +176,82 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 # normalized score for this would be -0.55 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } + } + }, + boost: 10.0 # normalized score for this would be -0.55 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" } - }, - boost: 9.0 # normalized score for this would be -0.56 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } + } + }, + boost: 9.0 # normalized score for this would be -0.56 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" } - }, - boost: 5.0 # normalized score for this would be -0.63 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "four" - } + } + }, + boost: 5.0 # normalized score for this would be -0.63 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" } - }, - boost: 1.0 # normalized score for this would be -0.7 - } + } + }, + boost: 1.0 # normalized score for this would be -0.7 } - ] - } + } + ] } } - }, - normalizer: { - minmax: { - min: 42, - max: 100 - } + } + }, + normalizer: { + minmax: { + min: 42, + max: 100 } } }, { # this only provides a score of 10 for doc 4 - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 10.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 10.0 } ] @@ -288,41 +276,37 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } + } + }, + boost: 10.0 } } - }, - weight: 1.0, - normalizer: { - aardvark: { } } + }, + weight: 1.0, + normalizer: { + aardvark: { } } }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] @@ -336,69 +320,65 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" } - }, - boost: 5.0 - } + } + }, + boost: 5.0 } - ] - } + } + ] } } - }, - weight: 10.0, - normalizer: "minmax" - } + } + }, + weight: 10.0, + normalizer: "minmax" }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] from: 2 @@ -417,69 +397,65 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" } - }, - boost: 5.0 - } + } + }, + boost: 5.0 } - ] - } + } + ] } } - }, - weight: 10.0, - normalizer: "minmax" - } + } + }, + weight: 10.0, + normalizer: "minmax" }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] from: 3 @@ -500,30 +476,26 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - match_all: { } - } + retriever: { + standard: { + query: { + match_all: { } } - }, - weight: 10.0, - normalizer: "minmax" - } + } + }, + weight: 10.0, + normalizer: "minmax" }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] rank_window_size: 2 @@ -538,69 +510,65 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" } - }, - boost: 5.0 - } + } + }, + boost: 5.0 } - ] - } + } + ] } } - }, - weight: 1.0, - normalizer: "minmax" - } + } + }, + weight: 1.0, + normalizer: "minmax" }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] rank_window_size: 2 @@ -622,58 +590,54 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "four" - } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" } - }, - boost: 1.0 - } + } + }, + boost: 1.0 } - ] - } - }, - _name: "my_standard_retriever" - } - }, - weight: 10.0, - normalizer: "minmax" - } + } + ] + } + }, + _name: "my_standard_retriever" + } + }, + weight: 10.0, + normalizer: "minmax" }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 20.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 20.0 } ] explain: true @@ -702,38 +666,34 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } + } + }, + boost: 10.0 } } - }, - weight: 0.5 - } + } + }, + weight: 0.5 }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] collapse: @@ -762,69 +722,60 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } + } + }, + boost: 10.0 } } - }, - weight: 0.5 - } + } + }, + weight: 0.5 }, { - component: { - retriever: { - linear: { - retrievers: [ - { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 20.0 + retriever: { + linear: { + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } } - } + }, + boost: 20.0 } } } - }, - { - component: - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - } - } } - ] - } - }, - weight: 2.0 - } + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + } + } + ] + } + }, + weight: 2.0 } ] @@ -846,38 +797,34 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } + } + }, + boost: 10.0 } } - }, - weight: 0.5 - } + } + }, + weight: 0.5 }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] filter: @@ -900,43 +847,39 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - filter: { - term: { - keyword: "four" - } + } + }, + boost: 10.0 + } + }, + filter: { + term: { + keyword: "four" } } - }, - weight: 0.5 - } + } + }, + weight: 0.5 }, { - component: { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 } ] @@ -956,43 +899,41 @@ setup: linear: retrievers: [ { - component: { - retriever: { - standard: { - query: { - constant_score: { - filter: { - bool: { - should: [ - { - term: { - keyword: { - value: "one" # this will give doc 1 a normalized score of 10 - } + retriever: { + standard: { + query: { + constant_score: { + filter: { + bool: { + should: [ + { + term: { + keyword: { + value: "one" # this will give doc 1 a normalized score of 10 } - }, - { - term: { - keyword: { - value: "two" # this will give doc 2 a normalized score of 10 - } + } + }, + { + term: { + keyword: { + value: "two" # this will give doc 2 a normalized score of 10 } - } ] - } - }, - boost: 10.0 - } - }, - sort: { - timestamp: { - order: "asc" - } + } + } ] + } + }, + boost: 10.0 + } + }, + sort: { + timestamp: { + order: "asc" } } - }, - weight: 1.0, - normalizer: "minmax" - } + } + }, + weight: 1.0, + normalizer: "minmax" }, { # because we're sorting on timestamp and use a rank window size of 2, we will only get to see @@ -1004,76 +945,74 @@ setup: # will be: # normalized(doc3) = 0.59989 # normalized(doc2) = 0.40010 - component: { - retriever: { - standard: { - query: { - function_score: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } + retriever: { + standard: { + query: { + function_score: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" } - }, - boost: 1.0 - } + } + }, + boost: 1.0 } - ] - } - }, - functions: [ { - script_score: { - script: { - source: "doc['timestamp'].value.millis" } + ] + } + }, + functions: [ { + script_score: { + script: { + source: "doc['timestamp'].value.millis" } - } ], - "boost_mode": "replace" - } - }, - sort: { - timestamp: { - order: "desc" - } + } + } ], + "boost_mode": "replace" + } + }, + sort: { + timestamp: { + order: "desc" } } - }, - weight: 1.0, - normalizer: { - minmax: { - min: 1577836800000, # 2020-01-01T00:00:00 - max: 1735689600000 # 2025-01-01T00:00:00 - } + } + }, + weight: 1.0, + normalizer: { + minmax: { + min: 1577836800000, # 2020-01-01T00:00:00 + max: 1735689600000 # 2025-01-01T00:00:00 } } } diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/retriever/IdentityScoreNormalizer.java similarity index 97% rename from server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java rename to server/src/main/java/org/elasticsearch/search/retriever/IdentityScoreNormalizer.java index 9a5c7af0c390f..68f0507ff3397 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/IdentityScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/IdentityScoreNormalizer.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.search.normalizer; +package org.elasticsearch.search.retriever; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ConstructingObjectParser; diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index d71c93f5b3cae..8b429cd5733fe 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; @@ -47,8 +46,6 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder { - p.nextToken(); - LinearRetrieverComponent retrieverBuilder = LinearRetrieverComponent.fromXContent(p, c); - p.nextToken(); - return retrieverBuilder; - }, RETRIEVERS_FIELD); + PARSER.declareObjectArray(constructorArg(), LinearRetrieverComponent::fromXContent, RETRIEVERS_FIELD); PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); RetrieverBuilder.declareBaseParserFields(NAME, PARSER); } @@ -145,7 +137,6 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { for (int rank = 0; rank < topResults.length; ++rank) { topResults[rank] = sortedResults[rank]; topResults[rank].rank = rank + 1; - topResults[rank].score = Math.max(EPSILON * (topResults.length - rank), topResults[rank].score); } return topResults; } @@ -161,7 +152,6 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.startArray(RETRIEVERS_FIELD.getPreferredName()); for (var entry : innerRetrievers) { builder.startObject(); - builder.startObject(LinearRetrieverComponent.NAME); builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index]); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java index 33b706292bb5e..1012b76ca10ed 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java @@ -9,10 +9,7 @@ package org.elasticsearch.search.retriever; -import org.elasticsearch.search.normalizer.IdentityScoreNormalizer; -import org.elasticsearch.search.normalizer.ScoreNormalizer; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.common.ParsingException; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -20,9 +17,8 @@ import java.io.IOException; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.RETRIEVERS_FIELD; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; public class LinearRetrieverComponent implements ToXContentObject { @@ -33,49 +29,15 @@ public class LinearRetrieverComponent implements ToXContentObject { static final float DEFAULT_WEIGHT = 1f; static final ScoreNormalizer DEFAULT_NORMALIZER = IdentityScoreNormalizer.INSTANCE; - public static final String NAME = "component"; - - static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, - false, - args -> { - RetrieverBuilder base = (RetrieverBuilder) args[0]; - float weight = args[1] == null ? DEFAULT_WEIGHT : (float) args[1]; - ScoreNormalizer normalizer = args[2] == null ? DEFAULT_NORMALIZER : (ScoreNormalizer) args[2]; - return new LinearRetrieverComponent(base, weight, normalizer); - } - ); - - static { - PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { - RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, n, c); - c.trackRetrieverUsage(retrieverBuilder.getName()); - return retrieverBuilder; - }, RETRIEVER_FIELD); - PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); - PARSER.declareField(optionalConstructorArg(), (p, c) -> { - if (p.currentToken() == XContentParser.Token.VALUE_STRING) { - final String normalizer = p.text(); - return ScoreNormalizer.valueOf(normalizer); - } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { - p.nextToken(); - final String normalizerName = p.currentName(); - ScoreNormalizer scoreNormalizer = ScoreNormalizer.parse(normalizerName, p); - p.nextToken(); - return scoreNormalizer; - } - throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); - }, NORMALIZER_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); - } - RetrieverBuilder retriever; float weight; ScoreNormalizer normalizer; - public LinearRetrieverComponent(RetrieverBuilder base, float weight, ScoreNormalizer normalizer) { - this.retriever = base; - this.weight = weight; - this.normalizer = normalizer; + public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, ScoreNormalizer normalizer) { + assert retrieverBuilder != null; + this.retriever = retrieverBuilder; + this.weight = weight == null ? DEFAULT_WEIGHT : weight; + this.normalizer = normalizer == null ? DEFAULT_NORMALIZER : normalizer; } @Override @@ -85,6 +47,46 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { - return PARSER.apply(parser, context); + RetrieverBuilder retrieverBuilder = null; + Float weight = null; + ScoreNormalizer normalizer = null; + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + if (RETRIEVER_FIELD.match(parser.currentName(), parser.getDeprecationHandler())) { + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + parser.nextToken(); + ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.currentToken(), parser); + final String retrieverName = parser.currentName(); + parser.nextToken(); + retrieverBuilder = parser.namedObject(RetrieverBuilder.class, retrieverName, context); + parser.nextToken(); + } else if (WEIGHT_FIELD.match(parser.currentName(), parser.getDeprecationHandler())) { + parser.nextToken(); + weight = parser.floatValue(); + } else if (NORMALIZER_FIELD.match(parser.currentName(), parser.getDeprecationHandler())) { + parser.nextToken(); + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + normalizer = ScoreNormalizer.valueOf(parser.text()); + } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) { + parser.nextToken(); + normalizer = ScoreNormalizer.parse(parser.currentName(), parser); + parser.nextToken(); + } else { + throw new ParsingException(parser.getTokenLocation(), "Unsupported token [" + parser.currentToken() + "]"); + } + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "Expected [" + XContentParser.Token.FIELD_NAME + "] but got [" + parser.currentToken() + "] instead." + ); + } + } + if (retrieverBuilder == null) { + throw new IllegalArgumentException("Missing required field [" + RETRIEVER_FIELD.getPreferredName() + "]"); + } + return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer); } } diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java similarity index 98% rename from server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java rename to server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java index fa7c323f3b67f..f79516e03dacd 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.search.normalizer; +package org.elasticsearch.search.retriever; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ConstructingObjectParser; diff --git a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/retriever/ScoreNormalizer.java similarity index 98% rename from server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java rename to server/src/main/java/org/elasticsearch/search/retriever/ScoreNormalizer.java index d12f8e4497b08..ed1b0f21b785c 100644 --- a/server/src/main/java/org/elasticsearch/search/normalizer/ScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/ScoreNormalizer.java @@ -7,7 +7,7 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.search.normalizer; +package org.elasticsearch.search.retriever; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ToXContent; diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java index bfcc627c40773..4b0d1e7ea62cf 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java @@ -11,8 +11,6 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.search.normalizer.MinMaxScoreNormalizer; -import org.elasticsearch.search.normalizer.ScoreNormalizer; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; From 86db0bc0aded9b5251b7fa18047ce04dc2a9282a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 17 Jan 2025 11:51:07 +0200 Subject: [PATCH 27/57] fix test --- .../elasticsearch/search/retriever/LinearRetrieverBuilder.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 8b429cd5733fe..1e8b236f96bab 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -156,7 +156,6 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index]); builder.endObject(); - builder.endObject(); index++; } builder.endArray(); From cc2c0718119ec0186a877c11d73fbaab1099a6a4 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Sun, 19 Jan 2025 21:28:32 +0200 Subject: [PATCH 28/57] iter --- .../index/query/QueryRewriteContext.java | 9 ++++++ .../search/builder/SearchSourceBuilder.java | 1 + .../search/rank/LinearRankDoc.java | 32 ++++++++++++------- .../retriever/CompoundRetrieverBuilder.java | 5 ++- .../retriever/LinearRetrieverBuilder.java | 13 +++++--- .../retriever/RescorerRetrieverBuilder.java | 2 +- .../TestCompoundRetrieverBuilder.java | 2 +- .../retriever/QueryRuleRetrieverBuilder.java | 2 +- .../TextSimilarityRankRetrieverBuilder.java | 2 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 2 +- 10 files changed, 47 insertions(+), 23 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index 265a0c52593bd..93de0b9531c56 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -72,6 +72,7 @@ public class QueryRewriteContext { private final ResolvedIndices resolvedIndices; private final PointInTimeBuilder pit; private QueryRewriteInterceptor queryRewriteInterceptor; + protected boolean isExplain; public QueryRewriteContext( final XContentParserConfiguration parserConfiguration, @@ -262,6 +263,14 @@ public void setMapUnmappedFieldAsString(boolean mapUnmappedFieldAsString) { this.mapUnmappedFieldAsString = mapUnmappedFieldAsString; } + public boolean isExplain() { + return this.isExplain; + } + + public void isExplain(boolean explain) { + this.isExplain = explain; + } + public NamedWriteableRegistry getWriteableRegistry() { return writeableRegistry; } diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index 6d47493e4d063..aa4ac577b8da3 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -1168,6 +1168,7 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti highlightBuilder ) )); + context.isExplain(explain()); if (retrieverBuilder != null) { var newRetriever = retrieverBuilder.rewrite(context); if (newRetriever != retrieverBuilder) { diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java index 7b4ca79b1dc48..ad7b82a3431d9 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java @@ -32,15 +32,19 @@ public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, Stri super(doc, score, shardIndex); this.weights = weights; this.normalizers = normalizers; - this.normalizedScores = new float[normalizers.length]; - Arrays.fill(normalizedScores, 0f); + if (normalizers == null) { + this.normalizedScores = null; + } else { + this.normalizedScores = new float[normalizers.length]; + Arrays.fill(normalizedScores, 0f); + } } public LinearRankDoc(StreamInput in) throws IOException { super(in); - weights = in.readFloatArray(); - normalizedScores = in.readFloatArray(); - normalizers = in.readStringArray(); + weights = in.readOptionalFloatArray(); + normalizedScores = in.readOptionalFloatArray(); + normalizers = in.readOptionalStringArray(); } @Override @@ -87,16 +91,22 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeFloatArray(weights); - out.writeFloatArray(normalizedScores); - out.writeStringArray(normalizers); + out.writeOptionalFloatArray(weights); + out.writeOptionalFloatArray(normalizedScores); + out.writeOptionalStringArray(normalizers); } @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("weights", weights); - builder.field("normalizedScores", normalizedScores); - builder.field("normalizers", normalizers); + if (weights != null) { + builder.field("weights", weights); + } + if (normalizedScores != null) { + builder.field("normalizedScores", normalizedScores); + } + if (normalizers != null) { + builder.field("normalizers", normalizers); + } } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index ffd3bfc796f80..7588208a64e67 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -78,7 +78,7 @@ public T addChild(RetrieverBuilder retrieverBuilder) { /** * Combines the provided {@code rankResults} to return the final top documents. */ - protected abstract RankDoc[] combineInnerRetrieverResults(List rankResults); + protected abstract RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain); @Override public final boolean isCompound() { @@ -181,7 +181,7 @@ public void onResponse(MultiSearchResponse items) { failures.forEach(ex::addSuppressed); listener.onFailure(ex); } else { - results.set(combineInnerRetrieverResults(topDocs)); + results.set(combineInnerRetrieverResults(topDocs, ctx.isExplain())); listener.onResponse(null); } } @@ -192,7 +192,6 @@ public void onFailure(Exception e) { } }); }); - return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 1e8b236f96bab..63ded096c9d53 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -102,8 +102,11 @@ protected LinearRetrieverBuilder clone(List newChildRetrievers, } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); + final String[] normalizerNames = false == isExplain + ? null + : Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { ScoreDoc[] originalScoreDocs = rankResults.get(result); ScoreDoc[] normalizedScoreDocs = normalizers[result].normalizeScores(originalScoreDocs); @@ -118,11 +121,13 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { originalScoreDocs[finalScoreIndex].doc, 0f, originalScoreDocs[finalScoreIndex].shardIndex, - weights, - Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new) + false == isExplain ? null : weights, + normalizerNames ); } - value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; + if (isExplain) { + value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; + } value.score += weights[finalResult] * normalizedScoreDocs[finalScoreIndex].score; return value; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java index 09688b5b9b001..9ef43e247672c 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java @@ -148,7 +148,7 @@ protected RescorerRetrieverBuilder clone(List newChildRetriever } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { assert rankResults.size() == 1; ScoreDoc[] scoreDocs = rankResults.getFirst(); RankDoc[] rankDocs = new RankDoc[scoreDocs.length]; diff --git a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java index 4a5f280c10a99..e08f68f82b824 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java @@ -38,7 +38,7 @@ protected TestCompoundRetrieverBuilder clone(List newChildRetri } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { return new RankDoc[0]; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 528204f4132ea..5e1cb22c53813 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -164,7 +164,7 @@ protected QueryRuleRetrieverBuilder clone(List newChildRetrieve } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { assert rankResults.size() == 1; ScoreDoc[] scoreDocs = rankResults.getFirst(); RankDoc[] rankDocs = new RuleQueryRankDoc[scoreDocs.length]; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 42248d246d3da..a77748d107715 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -136,7 +136,7 @@ protected TextSimilarityRankRetrieverBuilder clone( } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { assert rankResults.size() == 1; ScoreDoc[] scoreDocs = rankResults.getFirst(); TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length]; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index a749a7c402c30..e5c58fef6a994 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -105,7 +105,7 @@ protected RRFRetrieverBuilder clone(List newRetrievers, List rankResults) { + protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { // combine the disjointed sets of TopDocs into a single set or RRFRankDocs // each RRFRankDoc will have both the position and score for each query where // it was within the result set for that query From 90ef7f30bdf15de245f7b0f0aa8517e75d139c4d Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Sun, 19 Jan 2025 21:33:56 +0200 Subject: [PATCH 29/57] addressing PR comments - adding exception for unknown tokens during parsing --- .../search/retriever/LinearRetrieverComponent.java | 5 +++++ .../search/retriever/MinMaxScoreNormalizer.java | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java index 1012b76ca10ed..8b3e4ce18d11b 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java @@ -76,6 +76,11 @@ public static LinearRetrieverComponent fromXContent(XContentParser parser, Retri } else { throw new ParsingException(parser.getTokenLocation(), "Unsupported token [" + parser.currentToken() + "]"); } + } else { + throw new ParsingException( + parser.getTokenLocation(), + "Unexpected token [" + parser.currentToken() + "] for linear retriever." + ); } } else { throw new ParsingException( diff --git a/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java b/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java index f79516e03dacd..bfd0ba35cbca0 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java @@ -90,9 +90,10 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { if (min > max) { throw new IllegalArgumentException("[min=" + min + "] must be less than [max=" + max + "]"); } + boolean minEqualsMax = min.equals(max); for (int i = 0; i < docs.length; i++) { float score; - if (min.equals(max)) { + if (minEqualsMax) { score = min; } else { score = correction + (docs[i].score - min) / (max - min); From 7d6feeda9f664a62b0c6ee2396366e4d7983da4d Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Sun, 19 Jan 2025 21:47:07 +0200 Subject: [PATCH 30/57] iter --- .../org/elasticsearch/search/builder/SearchSourceBuilder.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index aa4ac577b8da3..251cd76345f88 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -1168,7 +1168,7 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti highlightBuilder ) )); - context.isExplain(explain()); + context.isExplain(explain() != null && explain()); if (retrieverBuilder != null) { var newRetriever = retrieverBuilder.rewrite(context); if (newRetriever != retrieverBuilder) { From 512952df429d74cf39d89a42146d4614e3a3143a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 21 Jan 2025 22:05:19 +0200 Subject: [PATCH 31/57] reverting optimization to avoid populating rank docs for explain, as it will be handled in another PR --- .../index/query/QueryRewriteContext.java | 9 ------ .../search/builder/SearchSourceBuilder.java | 1 - .../search/rank/LinearRankDoc.java | 31 ++++++------------- .../retriever/CompoundRetrieverBuilder.java | 4 +-- .../retriever/LinearRetrieverBuilder.java | 12 +++---- .../retriever/RescorerRetrieverBuilder.java | 2 +- .../TestCompoundRetrieverBuilder.java | 2 +- .../retriever/QueryRuleRetrieverBuilder.java | 2 +- .../TextSimilarityRankRetrieverBuilder.java | 2 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 2 +- 10 files changed, 21 insertions(+), 46 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java index 93de0b9531c56..265a0c52593bd 100644 --- a/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/QueryRewriteContext.java @@ -72,7 +72,6 @@ public class QueryRewriteContext { private final ResolvedIndices resolvedIndices; private final PointInTimeBuilder pit; private QueryRewriteInterceptor queryRewriteInterceptor; - protected boolean isExplain; public QueryRewriteContext( final XContentParserConfiguration parserConfiguration, @@ -263,14 +262,6 @@ public void setMapUnmappedFieldAsString(boolean mapUnmappedFieldAsString) { this.mapUnmappedFieldAsString = mapUnmappedFieldAsString; } - public boolean isExplain() { - return this.isExplain; - } - - public void isExplain(boolean explain) { - this.isExplain = explain; - } - public NamedWriteableRegistry getWriteableRegistry() { return writeableRegistry; } diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index 251cd76345f88..6d47493e4d063 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -1168,7 +1168,6 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti highlightBuilder ) )); - context.isExplain(explain() != null && explain()); if (retrieverBuilder != null) { var newRetriever = retrieverBuilder.rewrite(context); if (newRetriever != retrieverBuilder) { diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java index ad7b82a3431d9..f6b3d09afbdcb 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java @@ -32,19 +32,14 @@ public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, Stri super(doc, score, shardIndex); this.weights = weights; this.normalizers = normalizers; - if (normalizers == null) { - this.normalizedScores = null; - } else { - this.normalizedScores = new float[normalizers.length]; - Arrays.fill(normalizedScores, 0f); - } + this.normalizedScores = new float[normalizers.length]; } public LinearRankDoc(StreamInput in) throws IOException { super(in); - weights = in.readOptionalFloatArray(); - normalizedScores = in.readOptionalFloatArray(); - normalizers = in.readOptionalStringArray(); + weights = in.readFloatArray(); + normalizedScores = in.readFloatArray(); + normalizers = in.readStringArray(); } @Override @@ -91,22 +86,16 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeOptionalFloatArray(weights); - out.writeOptionalFloatArray(normalizedScores); - out.writeOptionalStringArray(normalizers); + out.writeFloatArray(weights); + out.writeFloatArray(normalizedScores); + out.writeStringArray(normalizers); } @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - if (weights != null) { - builder.field("weights", weights); - } - if (normalizedScores != null) { - builder.field("normalizedScores", normalizedScores); - } - if (normalizers != null) { - builder.field("normalizers", normalizers); - } + builder.field("weights", weights); + builder.field("normalizedScores", normalizedScores); + builder.field("normalizers", normalizers); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 7588208a64e67..830367d9479fa 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -78,7 +78,7 @@ public T addChild(RetrieverBuilder retrieverBuilder) { /** * Combines the provided {@code rankResults} to return the final top documents. */ - protected abstract RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain); + protected abstract RankDoc[] combineInnerRetrieverResults(List rankResults); @Override public final boolean isCompound() { @@ -181,7 +181,7 @@ public void onResponse(MultiSearchResponse items) { failures.forEach(ex::addSuppressed); listener.onFailure(ex); } else { - results.set(combineInnerRetrieverResults(topDocs, ctx.isExplain())); + results.set(combineInnerRetrieverResults(topDocs)); listener.onResponse(null); } } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java index 63ded096c9d53..6d02e840cb4dd 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java @@ -102,11 +102,9 @@ protected LinearRetrieverBuilder clone(List newChildRetrievers, } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); - final String[] normalizerNames = false == isExplain - ? null - : Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); + final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { ScoreDoc[] originalScoreDocs = rankResults.get(result); ScoreDoc[] normalizedScoreDocs = normalizers[result].normalizeScores(originalScoreDocs); @@ -121,13 +119,11 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b originalScoreDocs[finalScoreIndex].doc, 0f, originalScoreDocs[finalScoreIndex].shardIndex, - false == isExplain ? null : weights, + weights, normalizerNames ); } - if (isExplain) { - value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; - } + value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; value.score += weights[finalResult] * normalizedScoreDocs[finalScoreIndex].score; return value; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java index 9ef43e247672c..09688b5b9b001 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java @@ -148,7 +148,7 @@ protected RescorerRetrieverBuilder clone(List newChildRetriever } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { assert rankResults.size() == 1; ScoreDoc[] scoreDocs = rankResults.getFirst(); RankDoc[] rankDocs = new RankDoc[scoreDocs.length]; diff --git a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java index e08f68f82b824..4a5f280c10a99 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java @@ -38,7 +38,7 @@ protected TestCompoundRetrieverBuilder clone(List newChildRetri } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { return new RankDoc[0]; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 5e1cb22c53813..528204f4132ea 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -164,7 +164,7 @@ protected QueryRuleRetrieverBuilder clone(List newChildRetrieve } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { assert rankResults.size() == 1; ScoreDoc[] scoreDocs = rankResults.getFirst(); RankDoc[] rankDocs = new RuleQueryRankDoc[scoreDocs.length]; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index a77748d107715..42248d246d3da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -136,7 +136,7 @@ protected TextSimilarityRankRetrieverBuilder clone( } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { assert rankResults.size() == 1; ScoreDoc[] scoreDocs = rankResults.getFirst(); TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length]; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index e5c58fef6a994..a749a7c402c30 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -105,7 +105,7 @@ protected RRFRetrieverBuilder clone(List newRetrievers, List rankResults, boolean isExplain) { + protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults) { // combine the disjointed sets of TopDocs into a single set or RRFRankDocs // each RRFRankDoc will have both the position and score for each query where // it was within the result set for that query From 4f97a819865d6d54898fb594aadc1f71b331c67c Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 10:42:06 +0200 Subject: [PATCH 32/57] moving linear retriever to xpack and adding integ tests --- .../elasticsearch/search/SearchModule.java | 2 - .../search/retriever/RetrieversFeatures.java | 2 +- .../LinearRetrieverBuilderParsingTests.java | 8 +- .../xpack/rank/linear/LinearRetrieverIT.java | 763 ++++++++++++ .../rank-rrf/src/main/java/module-info.java | 6 +- .../RRFFeatures.java => RankRRFFeatures.java} | 14 +- .../rank/linear}/LinearRetrieverBuilder.java | 46 +- .../linear}/LinearRetrieverComponent.java | 16 +- .../xpack/rank/rrf/RRFRankPlugin.java | 11 +- ...lasticsearch.features.FeatureSpecification | 2 +- .../rrf/LinearRankClientYamlTestSuiteIT.java | 45 + .../test/linear/10_linear_retriever.yml | 1028 +++++++++++++++++ 12 files changed, 1913 insertions(+), 30 deletions(-) rename {server/src/test/java/org/elasticsearch/search/retriever => x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear}/LinearRetrieverBuilderParsingTests.java (89%) create mode 100644 x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java rename x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/{rrf/RRFFeatures.java => RankRRFFeatures.java} (65%) rename {server/src/main/java/org/elasticsearch/search/retriever => x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear}/LinearRetrieverBuilder.java (81%) rename {server/src/main/java/org/elasticsearch/search/retriever => x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear}/LinearRetrieverComponent.java (88%) create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java create mode 100644 x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index ec3f64fc9a591..bbc89564326dc 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -232,7 +232,6 @@ import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; -import org.elasticsearch.search.retriever.LinearRetrieverBuilder; import org.elasticsearch.search.retriever.RescorerRetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; @@ -1086,7 +1085,6 @@ private void registerRetrieverParsers(List plugins) { registerRetriever(new RetrieverSpec<>(StandardRetrieverBuilder.NAME, StandardRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(KnnRetrieverBuilder.NAME, KnnRetrieverBuilder::fromXContent)); registerRetriever(new RetrieverSpec<>(RescorerRetrieverBuilder.NAME, RescorerRetrieverBuilder::fromXContent)); - registerRetriever(new RetrieverSpec<>(LinearRetrieverBuilder.NAME, LinearRetrieverBuilder::fromXContent)); registerFromPlugin(plugins, SearchPlugin::getRetrievers, this::registerRetriever); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java index c94d845938db7..bfd6f572a9e65 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieversFeatures.java @@ -22,6 +22,6 @@ public class RetrieversFeatures implements FeatureSpecification { @Override public Set getFeatures() { - return Set.of(LinearRetrieverBuilder.LINEAR_RETRIEVER_SUPPORTED); + return Set.of(); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java similarity index 89% rename from server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java rename to x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index 4b0d1e7ea62cf..0f0c2ea0a7a44 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -7,10 +7,16 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.search.retriever; +package org.elasticsearch.xpack.rank.linear; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.MinMaxScoreNormalizer; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.ScoreNormalizer; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java new file mode 100644 index 0000000000000..527b224a454b7 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -0,0 +1,763 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.linear; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.InnerHitBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.bucket.terms.Terms; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.collapse.CollapseBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.StandardRetrieverBuilder; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +@ESIntegTestCase.ClusterScope(minNumDataNodes = 2) +public class LinearRetrieverIT extends ESIntegTestCase { + + protected static String INDEX = "test_index"; + protected static final String DOC_FIELD = "doc"; + protected static final String TEXT_FIELD = "text"; + protected static final String VECTOR_FIELD = "vector"; + protected static final String TOPIC_FIELD = "topic"; + + @Override + protected Collection> nodePlugins() { + return List.of(RRFRankPlugin.class); + } + + @Before + public void setup() throws Exception { + setupIndex(); + } + + protected void setupIndex() { + String mapping = """ + { + "properties": { + "vector": { + "type": "dense_vector", + "dims": 1, + "element_type": "float", + "similarity": "l2_norm", + "index": true, + "index_options": { + "type": "hnsw" + } + }, + "text": { + "type": "text" + }, + "doc": { + "type": "keyword" + }, + "topic": { + "type": "keyword" + }, + "views": { + "type": "nested", + "properties": { + "last30d": { + "type": "integer" + }, + "all": { + "type": "integer" + } + } + } + } + } + """; + createIndex(INDEX, Settings.builder().put(SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5)).build()); + admin().indices().preparePutMapping(INDEX).setSource(mapping, XContentType.JSON).get(); + indexDoc(INDEX, "doc_1", DOC_FIELD, "doc_1", TOPIC_FIELD, "technology", TEXT_FIELD, "term"); + indexDoc( + INDEX, + "doc_2", + DOC_FIELD, + "doc_2", + TOPIC_FIELD, + "astronomy", + TEXT_FIELD, + "search term term", + VECTOR_FIELD, + new float[]{2.0f} + ); + indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[]{3.0f}); + indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term"); + indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); + indexDoc( + INDEX, + "doc_6", + DOC_FIELD, + "doc_6", + TEXT_FIELD, + "search term term term term term term", + VECTOR_FIELD, + new float[]{6.0f} + ); + indexDoc( + INDEX, + "doc_7", + DOC_FIELD, + "doc_7", + TOPIC_FIELD, + "biology", + TEXT_FIELD, + "term term term term term term term", + VECTOR_FIELD, + new float[]{7.0f} + ); + refresh(INDEX); + } + + + public void testLinearRetrieverWithAggs() { + final int rankWindowSize = 100; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + + // all requests would have an equal weight and use the identity normalizer + source.retriever( + new org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.size(1); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + + assertNotNull(resp.getAggregations()); + assertNotNull(resp.getAggregations().get("topic_agg")); + Terms terms = resp.getAggregations().get("topic_agg"); + + assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); + assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); + assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); + }); + } + + public void testLinearWithCollapse() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.collapse( + new CollapseBuilder(TOPIC_FIELD).setInnerHits( + new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) + ) + ); + source.fetchField(TOPIC_FIELD); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(4)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + }); + } + + public void testRRFRetrieverWithCollapseAndAggs() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.collapse( + new CollapseBuilder(TOPIC_FIELD).setInnerHits( + new InnerHitBuilder("a").addSort(new FieldSortBuilder(DOC_FIELD).order(SortOrder.DESC)).setSize(10) + ) + ); + source.fetchField(TOPIC_FIELD); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(4)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + + assertNotNull(resp.getAggregations()); + assertNotNull(resp.getAggregations().get("topic_agg")); + Terms terms = resp.getAggregations().get("topic_agg"); + + assertThat(terms.getBucketByKey("technology").getDocCount(), equalTo(3L)); + assertThat(terms.getBucketByKey("astronomy").getDocCount(), equalTo(1L)); + assertThat(terms.getBucketByKey("biology").getDocCount(), equalTo(1L)); + }); + } + + public void testMultipleRRFRetrievers() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource( + // this one returns docs 6, 7, 1, 3, and 4 + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ), + null + ), + // this one bring just doc 7 which should be ranked first eventually + new CompoundRetrieverBuilder.RetrieverSource( + new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{7.0f}, null, 1, 100, null, null), + null + ) + ), + rankWindowSize + ) + ); + + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_4")); + }); + } + + public void testRRFExplainWithNamedRetrievers() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + standard0.retrieverName("my_custom_retriever"); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + source.explain(true); + source.size(1); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); + assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); + assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); + var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(rrfDetails.getDetails().length, equalTo(3)); + assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 1]")); + + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); + assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [1] in query at index [1]")); + assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [1] in query at index [2]")); + }); + } + + public void testRRFExplainWithAnotherNestedRRF() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this one retrieves docs 1, 2, 4, 6, and 7 + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(9L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_4")).boost(8L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(7L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) + ); + standard0.retrieverName("my_custom_retriever"); + // this one retrieves docs 2 and 6 due to prefilter + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + // this one retrieves docs 2, 3, 6, and 7 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + + LinearRetrieverBuilder nestedRRF = new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + ), + rankWindowSize + ); + StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L) + ); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(nestedRRF, null), + new CompoundRetrieverBuilder.RetrieverSource(standard2, null) + ), + rankWindowSize + ) + ); + source.explain(true); + source.size(1); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getHits().length, equalTo(1)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); + assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); + assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); + var rrfTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(rrfTopLevel.getDetails().length, equalTo(2)); + assertThat(rrfTopLevel.getDescription(), containsString("computed for initial ranks [2, 1]")); + assertThat(rrfTopLevel.getDetails()[0].getDetails()[0].getDescription(), containsString("rrf score")); + assertThat(rrfTopLevel.getDetails()[1].getDetails()[0].getDescription(), containsString("ConstantScore")); + + var rrfDetails = rrfTopLevel.getDetails()[0].getDetails()[0]; + assertThat(rrfDetails.getDetails().length, equalTo(3)); + assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [4, 2, 3]")); + + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [4] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [4] in query at index [0]")); + assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); + assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [2] in query at index [1]")); + assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [3] in query at index [2]")); + }); + } + + public void testRRFInnerRetrieverAll4xxSearchErrors() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will throw a 4xx error during evaluation + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10)) + ); + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize + ) + ); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + Exception ex = expectThrows(ElasticsearchStatusException.class, req::get); + assertThat(ex, instanceOf(ElasticsearchStatusException.class)); + assertThat( + ex.getMessage(), + containsString( + "[rrf] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions." + ) + ); + assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST)); + assertThat(ex.getSuppressed().length, equalTo(1)); + assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class)); + } + + public void testRRFInnerRetrieverMultipleErrorsOne5xx() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will throw a 4xx error during evaluation + StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( + QueryBuilders.constantScoreQuery(QueryBuilders.rangeQuery(VECTOR_FIELD).gte(10)) + ); + // this will throw a 5xx error + TestRetrieverBuilder testRetrieverBuilder = new TestRetrieverBuilder("val") { + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + searchSourceBuilder.aggregation(AggregationBuilders.avg("some_invalid_param")); + } + }; + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standard0, null), + new CompoundRetrieverBuilder.RetrieverSource(testRetrieverBuilder, null) + ), + rankWindowSize + ) + ); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + Exception ex = expectThrows(ElasticsearchStatusException.class, req::get); + assertThat(ex, instanceOf(ElasticsearchStatusException.class)); + assertThat( + ex.getMessage(), + containsString( + "[rrf] search failed - retrievers '[standard, test]' returned errors. All failures are attached as suppressed exceptions." + ) + ); + assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); + assertThat(ex.getSuppressed().length, equalTo(2)); + assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class)); + assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class)); + } + + public void testRRFInnerRetrieverErrorWhenExtractingToSource() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { + @Override + public QueryBuilder topDocsQuery() { + return QueryBuilders.matchAllQuery(); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + throw new UnsupportedOperationException("simulated failure"); + } + }; + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize + ) + ); + source.size(1); + expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); + } + + public void testRRFInnerRetrieverErrorOnTopDocs() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { + @Override + public QueryBuilder topDocsQuery() { + throw new UnsupportedOperationException("simulated failure"); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + } + }; + StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( + QueryBuilders.boolQuery() + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_3")).boost(10L)) + .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) + ); + standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(failingRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) + ), + rankWindowSize + ) + ); + source.size(1); + source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD)); + expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); + } + + public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will retriever all but 7 only due to top-level filter + StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); + // this will too retrieve just doc 7 + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + "vector", + null, + new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[]{3}), + 10, + 10, + null, + null + ); + source.retriever( + new LinearRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null) + ), + rankWindowSize + ) + ); + source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7"))); + source.size(10); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(1L)); + assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7")); + }); + } + + public void testRewriteOnce() { + final float[] vector = new float[]{1}; + AtomicInteger numAsyncCalls = new AtomicInteger(); + QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() { + @Override + public void buildVector(Client client, ActionListener listener) { + numAsyncCalls.incrementAndGet(); + listener.onResponse(vector); + } + + @Override + public String getWriteableName() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + throw new IllegalStateException("Should not be called"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IllegalStateException("Should not be called"); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + throw new IllegalStateException("Should not be called"); + } + }; + var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null); + var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null)); + var rrf = new LinearRetrieverBuilder( + List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)), + 10 + ); + assertResponse( + client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(2)); + + // check that we use the rewritten vector to build the explain query + assertResponse( + client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)), + searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value(), is(4L)) + ); + assertThat(numAsyncCalls.get(), equalTo(4)); + } +} diff --git a/x-pack/plugin/rank-rrf/src/main/java/module-info.java b/x-pack/plugin/rank-rrf/src/main/java/module-info.java index 4fd2a7e4d54f3..fbe467fdf3eae 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/module-info.java +++ b/x-pack/plugin/rank-rrf/src/main/java/module-info.java @@ -5,7 +5,7 @@ * 2.0. */ -import org.elasticsearch.xpack.rank.rrf.RRFFeatures; +import org.elasticsearch.xpack.rank.RankRRFFeatures; module org.elasticsearch.rank.rrf { requires org.apache.lucene.core; @@ -14,7 +14,9 @@ requires org.elasticsearch.server; requires org.elasticsearch.xcore; + exports org.elasticsearch.xpack.rank; exports org.elasticsearch.xpack.rank.rrf; + exports org.elasticsearch.xpack.rank.linear; - provides org.elasticsearch.features.FeatureSpecification with RRFFeatures; + provides org.elasticsearch.features.FeatureSpecification with RankRRFFeatures; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java similarity index 65% rename from x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java index 494eaa508c14a..5966e17f20429 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.rank.rrf; +package org.elasticsearch.xpack.rank; import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; @@ -14,10 +14,14 @@ import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT; -/** - * A set of features specifically for the rrf plugin. - */ -public class RRFFeatures implements FeatureSpecification { +public class RankRRFFeatures implements FeatureSpecification { + + public static final NodeFeature LINEAR_RETRIEVER_SUPPORTED = new NodeFeature("linear_retriever_supported"); + + @Override + public Set getFeatures() { + return Set.of(LINEAR_RETRIEVER_SUPPORTED); + } @Override public Set getTestFeatures() { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java similarity index 81% rename from server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 6d02e840cb4dd..89999e873c651 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -1,26 +1,31 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.retriever; +package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.util.Maps; -import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.IdentityScoreNormalizer; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; import java.io.IOException; import java.util.ArrayList; @@ -30,6 +35,8 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.rank.RankRRFFeatures.LINEAR_RETRIEVER_SUPPORTED; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT; /** * The {@code LinearRetrieverBuilder} supports the combination of different retrievers through a weighted linear combination. @@ -43,7 +50,6 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder innerRetrievers, + int rankWindowSize + ) { + this(innerRetrievers, rankWindowSize, null, null); + } + public LinearRetrieverBuilder( List innerRetrievers, int rankWindowSize, @@ -90,8 +106,18 @@ public LinearRetrieverBuilder( ScoreNormalizer[] normalizers ) { super(innerRetrievers, rankWindowSize); - this.weights = weights; - this.normalizers = normalizers; + if (weights == null) { + this.weights = new float[innerRetrievers.size()]; + Arrays.fill(this.weights, DEFAULT_WEIGHT); + } else { + this.weights = weights; + } + if (normalizers == null) { + this.normalizers = new ScoreNormalizer[innerRetrievers.size()]; + Arrays.fill(this.normalizers, IdentityScoreNormalizer.INSTANCE); + } else { + this.normalizers = normalizers; + } } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java similarity index 88% rename from server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index 8b3e4ce18d11b..168c94ddda6c9 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -1,15 +1,17 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.retriever; +package org.elasticsearch.xpack.rank.linear; import org.elasticsearch.common.ParsingException; +import org.elasticsearch.search.retriever.IdentityScoreNormalizer; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -18,7 +20,7 @@ import java.io.IOException; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.search.retriever.LinearRetrieverBuilder.RETRIEVERS_FIELD; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.RETRIEVERS_FIELD; public class LinearRetrieverComponent implements ToXContentObject { diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java index 9404d863f1d28..8d19337a0974d 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java @@ -17,6 +17,7 @@ import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; import java.util.List; @@ -28,6 +29,12 @@ public class RRFRankPlugin extends Plugin implements SearchPlugin { License.OperationMode.ENTERPRISE ); + public static final LicensedFeature.Momentary LINEAR_RETRIEVER_FEATURE = LicensedFeature.momentary( + null, + "linear-retriever", + License.OperationMode.ENTERPRISE + ); + public static final String NAME = "rrf"; @Override @@ -46,6 +53,8 @@ public List getNamedXContent() { @Override public List> getRetrievers() { - return List.of(new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent)); + return List.of( + new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent), + new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent)); } } diff --git a/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification b/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification index 605e999b66c66..528b7e35bee65 100644 --- a/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification +++ b/x-pack/plugin/rank-rrf/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification @@ -5,4 +5,4 @@ # 2.0. # -org.elasticsearch.xpack.rank.rrf.RRFFeatures +org.elasticsearch.xpack.rank.RankRRFFeatures diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java new file mode 100644 index 0000000000000..8af4ae307a51a --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/LinearRankClientYamlTestSuiteIT.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.rrf; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; +import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; +import org.junit.ClassRule; + +/** Runs yaml rest tests. */ +public class LinearRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase { + + @ClassRule + public static ElasticsearchCluster cluster = ElasticsearchCluster.local() + .nodes(2) + .module("mapper-extras") + .module("rank-rrf") + .module("lang-painless") + .module("x-pack-inference") + .setting("xpack.license.self_generated.type", "trial") + .plugin("inference-service-test") + .build(); + + public LinearRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return ESClientYamlSuiteTestCase.createParameters(new String[] { "linear" }); + } + + @Override + protected String getTestRestCluster() { + return cluster.getHttpAddresses(); + } +} diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml new file mode 100644 index 0000000000000..d82d2f4d41c4f --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -0,0 +1,1028 @@ +setup: + - requires: + cluster_features: [ "linear_retriever_supported" ] + reason: "Support for linear retriever" + test_runner_features: close_to + + - do: + indices.create: + index: test + body: + mappings: + properties: + vector: + type: dense_vector + dims: 1 + index: true + similarity: l2_norm + keyword: + type: keyword + other_keyword: + type: keyword + timestamp: + type: date + + - do: + bulk: + refresh: true + index: test + body: + - '{"index": {"_id": 1 }}' + - '{"vector": [1], "keyword": "one", "other_keyword": "other", "timestamp": "2021-01-01T00:00:00"}' + - '{"index": {"_id": 2 }}' + - '{"vector": [2], "keyword": "two", "timestamp": "2022-01-01T00:00:00"}' + - '{"index": {"_id": 3 }}' + - '{"vector": [3], "keyword": "three", "timestamp": "2023-01-01T00:00:00"}' + - '{"index": {"_id": 4 }}' + - '{"vector": [4], "keyword": "four", "other_keyword": "other", "timestamp": "2024-01-01T00:00:00"}' + +--- +"basic linear weighted combination of a standard and knn retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 5.0 } + - match: { hits.hits.1._id: "4" } + - match: { hits.hits.1._score: 2.0 } + +--- +"should normalize initial scores": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "1" } + - match: {hits.hits.0._score: 10.0} + - match: { hits.hits.1._id: "2" } + - match: {hits.hits.1._score: 8.0} + - match: { hits.hits.2._id: "4" } + - match: {hits.hits.2._score: 2.0} + - match: { hits.hits.2._score: 2.0 } + - match: { hits.hits.3._id: "3" } + - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } + +--- +"should normalize initial scores with a custom minmax normalizer": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 # normalized score for this would be -0.55 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 # normalized score for this would be -0.56 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 # normalized score for this would be -0.63 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 1.0 # normalized score for this would be -0.7 + } + } + ] + } + } + } + }, + normalizer: { + minmax: { + min: 42, + max: 100 + } + } + }, + { + # this only provides a score of 10 for doc 4 + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 10.0 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - close_to: { hits.hits.0._score: { value: 50.2931, error: 0.001 } } + - match: { hits.hits.1._id: "1" } + - close_to: { hits.hits.1._score: { value: 40.4482, error: 0.001 } } + - match: { hits.hits.2._id: "2" } + - close_to: { hits.hits.2._score: { value: 40.4310, error: 0.001 } } + - match: { hits.hits.3._id: "3" } + - close_to: { hits.hits.3._score: { value: 40.3620, error: 0.001 } } + +--- +"should throw on unknown normalizer": + - do: + catch: /Unknown normalizer \[aardvark\]/ + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 1.0, + normalizer: { + aardvark: { } + } + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + +--- +"pagination within a consistent rank_window_size": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + from: 2 + size: 1 + + - match: { hits.total.value: 4 } + - length: { hits.hits: 1 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + from: 3 + size: 1 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "3" } + - close_to: { hits.hits.0._score: { value: 0.0, error: 0.001 } } + +--- +"should throw when rank_window_size less than size": + - do: + catch: "/\\[linear\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/" + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + match_all: { } + } + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + rank_window_size: 2 + size: 10 +--- +"should respect rank_window_size for normalization and returned hits": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + } + ] + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + rank_window_size: 2 + size: 2 + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 1.0 } + +--- +"explain should provide info on weights and inner retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 1.0 + } + } + ] + } + }, + _name: "my_standard_retriever" + } + }, + weight: 10.0, + normalizer: "minmax" + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 20.0 + } + ] + explain: true + size: 2 + + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._explanation.description: "/weighted.linear.combination.score:.\\[20.0].computed.for.normalized.scores.\\[.*,.1.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} + - match: { hits.hits.0._explanation.details.0.value: 0.0 } + - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[0\\].\\[my_standard_retriever\\]/" } + - match: { hits.hits.0._explanation.details.1.value: 20.0 } + - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[none\\].*/" } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._explanation.description: "/weighted.linear.combination.score:.\\[10.0].computed.for.normalized.scores.\\[1.0,.0.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} + - match: { hits.hits.1._explanation.details.0.value: 10.0 } + - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[minmax\\].*/" } + - match: { hits.hits.1._explanation.details.1.value: 0.0 } + - match: { hits.hits.1._explanation.details.1.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[1\\]/" } + +--- +"collapsing results": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + collapse: + field: other_keyword + inner_hits: { + name: sub_hits, + sort: + { + keyword: { + order: desc + } + } + } + - match: { hits.hits.0._id: "1" } + - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "1" } + - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "4" } + +--- +"multiple nested linear retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + linear: { + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 20.0 + } + } + } + } + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + } + } + ] + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.0._score: 40.0 } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 5.0 } + - match: { hits.hits.2._id: "4" } + - match: { hits.hits.2._score: 2.0 } + +--- +"linear retriever with filters": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + filter: + term: + keyword: "four" + + + - match: { hits.total.value: 1 } + - length: {hits.hits: 1} + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + +--- +"linear retriever with filters on nested retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + filter: { + term: { + keyword: "four" + } + } + } + }, + weight: 0.5 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: 2.0 + } + ] + + - match: { hits.total.value: 1 } + - length: {hits.hits: 1} + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.0._score: 2.0 } + + +--- +"linear retriever with custom sort and score for nested retrievers": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + bool: { + should: [ + { + term: { + keyword: { + value: "one" # this will give doc 1 a normalized score of 10 + } + } + }, + { + term: { + keyword: { + value: "two" # this will give doc 2 a normalized score of 10 + } + } + } ] + } + }, + boost: 10.0 + } + }, + sort: { + timestamp: { + order: "asc" + } + } + } + }, + weight: 1.0, + normalizer: "minmax" + }, + { + # because we're sorting on timestamp and use a rank window size of 2, we will only get to see + # docs 3 and 2. + # their `scores` (which are the timestamps) are: + # doc 3: 1672531200000 (2023-01-01T00:00:00) + # doc 2: 1640995200000 (2022-01-01T00:00:00) + # and their normalized scores based on the provided conf + # will be: + # normalized(doc3) = 0.59989 + # normalized(doc2) = 0.40010 + retriever: { + standard: { + query: { + function_score: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 9.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 1.0 + } + } + ] + } + }, + functions: [ { + script_score: { + script: { + source: "doc['timestamp'].value.millis" + } + } + } ], + "boost_mode": "replace" + } + }, + sort: { + timestamp: { + order: "desc" + } + } + } + }, + weight: 1.0, + normalizer: { + minmax: { + min: 1577836800000, # 2020-01-01T00:00:00 + max: 1735689600000 # 2025-01-01T00:00:00 + } + } + } + ] + rank_window_size: 2 + size: 2 + + - match: { hits.total.value: 3 } + - length: {hits.hits: 2} + - match: { hits.hits.0._id: "2" } + - close_to: { hits.hits.0._score: { value: 10.4001, error: 0.001 } } + - match: { hits.hits.1._id: "1" } + - match: { hits.hits.1._score: 10 } From 629491793b745e79fcf3453bac19385ddc6e5c23 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 10:44:07 +0200 Subject: [PATCH 33/57] fixing license --- .../rank/linear/LinearRetrieverBuilderParsingTests.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index 0f0c2ea0a7a44..f868d60ea20bf 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -1,10 +1,8 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ package org.elasticsearch.xpack.rank.linear; From 78fdcdaa12154610f6a122eb48e0f13189dd2e4c Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 12:57:56 +0200 Subject: [PATCH 34/57] adding integ tests --- .../index/query/RankDocsQueryBuilder.java | 4 +- .../elasticsearch/search/SearchModule.java | 2 - .../retriever/CompoundRetrieverBuilder.java | 8 +- .../retriever/RankDocsRetrieverBuilder.java | 5 +- .../retriever/RescorerRetrieverBuilder.java | 1 + .../rank/linear}/LinearRankDocTests.java | 11 +- .../xpack/rank/linear/LinearRetrieverIT.java | 237 ++++++++++++------ .../xpack/rank/linear}/LinearRankDoc.java | 11 +- .../rank/linear/LinearRetrieverBuilder.java | 13 +- .../xpack/rank/rrf/RRFRankPlugin.java | 7 +- .../xpack/rank/rrf/RRFRetrieverBuilder.java | 1 + 11 files changed, 193 insertions(+), 107 deletions(-) rename {server/src/test/java/org/elasticsearch/search/rank => x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear}/LinearRankDocTests.java (87%) rename {server/src/main/java/org/elasticsearch/search/rank => x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear}/LinearRankDoc.java (90%) diff --git a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java index 889fa40b79aa1..524310c547597 100644 --- a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java @@ -70,7 +70,9 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws changed |= newQueryBuilders[i] != queryBuilders[i]; } if (changed) { - return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs); + RankDocsQueryBuilder clone = new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs); + clone.queryName(queryName()); + return clone; } } return super.doRewrite(queryRewriteContext); diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index bbc89564326dc..6716c03a3a935 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -224,7 +224,6 @@ import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.fetch.subphase.highlight.PlainHighlighter; import org.elasticsearch.search.internal.ShardSearchRequest; -import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.search.rank.feature.RankFeatureDoc; @@ -835,7 +834,6 @@ private void registerRescorer(RescorerSpec spec) { private void registerRankers() { namedWriteables.add(new NamedWriteableRegistry.Entry(RankDoc.class, RankDoc.NAME, RankDoc::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(RankDoc.class, RankFeatureDoc.NAME, RankFeatureDoc::new)); - namedWriteables.add(new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(RankShardResult.class, RankFeatureShardResult.NAME, RankFeatureShardResult::new) ); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 830367d9479fa..902a05b8e5c91 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -192,7 +192,13 @@ public void onFailure(Exception e) { } }); }); - return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get); + RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder( + rankWindowSize, + newRetrievers.stream().map(s -> s.retriever).toList(), + results::get + ); + rankDocsRetrieverBuilder.retrieverName(retrieverName()); + return rankDocsRetrieverBuilder; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index f873da8c71506..a77f5327fbc26 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -90,11 +90,13 @@ public QueryBuilder topDocsQuery() { @Override public QueryBuilder explainQuery() { - return new RankDocsQueryBuilder( + var explainQuery = new RankDocsQueryBuilder( rankDocs.get(), sources.stream().map(RetrieverBuilder::explainQuery).toArray(QueryBuilder[]::new), true ); + explainQuery.queryName(retrieverName()); + return explainQuery; } @Override @@ -123,6 +125,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder } else { rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false); } + rankQuery.queryName(retrieverName()); // ignore prefilters of this level, they were already propagated to children searchSourceBuilder.query(rankQuery); if (searchSourceBuilder.size() < 0) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java index 09688b5b9b001..6b2a44fdfe106 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java @@ -144,6 +144,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept protected RescorerRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { var newInstance = new RescorerRetrieverBuilder(newChildRetrievers.get(0), rescorers); newInstance.preFilterQueryBuilders = newPreFilterQueryBuilders; + newInstance.retrieverName = retrieverName; return newInstance; } diff --git a/server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java similarity index 87% rename from server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java rename to x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java index cbc3ef551139d..b7bef70ce3024 100644 --- a/server/src/test/java/org/elasticsearch/search/rank/LinearRankDocTests.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java @@ -1,16 +1,15 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.rank; +package org.elasticsearch.xpack.rank.linear; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; import java.io.IOException; import java.util.Collections; diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index 527b224a454b7..99514cdbdc337 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -26,7 +26,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.IdentityScoreNormalizer; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; +import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; @@ -50,6 +52,7 @@ import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -84,7 +87,7 @@ protected void setupIndex() { "similarity": "l2_norm", "index": true, "index_options": { - "type": "hnsw" + "type": "flat" } }, "text": { @@ -123,9 +126,9 @@ protected void setupIndex() { TEXT_FIELD, "search term term", VECTOR_FIELD, - new float[]{2.0f} + new float[] { 2.0f } ); - indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[]{3.0f}); + indexDoc(INDEX, "doc_3", DOC_FIELD, "doc_3", TOPIC_FIELD, "technology", VECTOR_FIELD, new float[] { 3.0f }); indexDoc(INDEX, "doc_4", DOC_FIELD, "doc_4", TOPIC_FIELD, "technology", TEXT_FIELD, "term term term term"); indexDoc(INDEX, "doc_5", DOC_FIELD, "doc_5", TOPIC_FIELD, "science", TEXT_FIELD, "irrelevant stuff"); indexDoc( @@ -136,7 +139,7 @@ protected void setupIndex() { TEXT_FIELD, "search term term term term term term", VECTOR_FIELD, - new float[]{6.0f} + new float[] { 6.0f } ); indexDoc( INDEX, @@ -148,12 +151,11 @@ protected void setupIndex() { TEXT_FIELD, "term term term term term term term", VECTOR_FIELD, - new float[]{7.0f} + new float[] { 7.0f } ); refresh(INDEX); } - public void testLinearRetrieverWithAggs() { final int rankWindowSize = 100; SearchSourceBuilder source = new SearchSourceBuilder(); @@ -175,7 +177,7 @@ public void testLinearRetrieverWithAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); // all requests would have an equal weight and use the identity normalizer source.retriever( @@ -211,9 +213,9 @@ public void testLinearRetrieverWithAggs() { public void testLinearWithCollapse() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) @@ -223,6 +225,7 @@ public void testLinearWithCollapse() { .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) ); // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) @@ -231,7 +234,15 @@ public void testLinearWithCollapse() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 source.retriever( new LinearRetrieverBuilder( Arrays.asList( @@ -256,20 +267,24 @@ public void testLinearWithCollapse() { assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getHits().length, equalTo(4)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + assertThat(resp.getHits().getAt(0).getScore(), equalTo(30f)); assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + assertThat((double) resp.getHits().getAt(1).getScore(), closeTo(12.0588f, 0.0001f)); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getScore(), equalTo(10f)); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); + assertThat((double) resp.getHits().getAt(3).getScore(), closeTo(6.0384f, 0.0001f)); }); } - public void testRRFRetrieverWithCollapseAndAggs() { + public void testLinearRetrieverWithCollapseAndAggs() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) @@ -279,6 +294,7 @@ public void testRRFRetrieverWithCollapseAndAggs() { .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) ); // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) @@ -287,7 +303,15 @@ public void testRRFRetrieverWithCollapseAndAggs() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 source.retriever( new LinearRetrieverBuilder( Arrays.asList( @@ -314,11 +338,11 @@ public void testRRFRetrieverWithCollapseAndAggs() { assertThat(resp.getHits().getHits().length, equalTo(4)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_7")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); - assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(3).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(0).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(1).getId(), equalTo("doc_3")); + assertThat(resp.getHits().getAt(2).getInnerHits().get("a").getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_7")); assertNotNull(resp.getAggregations()); assertNotNull(resp.getAggregations().get("topic_agg")); @@ -330,11 +354,11 @@ public void testRRFRetrieverWithCollapseAndAggs() { }); } - public void testMultipleRRFRetrievers() { + public void testMultipleLinearRetrievers() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) @@ -344,6 +368,7 @@ public void testMultipleRRFRetrievers() { .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_7")).boost(6L)) ); // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) @@ -351,30 +376,32 @@ public void testMultipleRRFRetrievers() { .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(5L)) ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); - // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); source.retriever( new LinearRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource( - // this one returns docs 6, 7, 1, 3, and 4 + // this one returns docs doc 2, 1, 6, 4, 7 + // with scores 38, 20, 19, 16, 12 new LinearRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource(standard0, null), - new CompoundRetrieverBuilder.RetrieverSource(standard1, null), - new CompoundRetrieverBuilder.RetrieverSource(knnRetrieverBuilder, null) + new CompoundRetrieverBuilder.RetrieverSource(standard1, null) ), - rankWindowSize + rankWindowSize, + new float[] { 2.0f, 1.0f }, + null ), null ), - // this one bring just doc 7 which should be ranked first eventually + // this one bring just doc 7 which should be ranked first eventually with a score of 100 new CompoundRetrieverBuilder.RetrieverSource( - new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{7.0f}, null, 1, 100, null, null), + new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null), null ) ), - rankWindowSize + rankWindowSize, + new float[] { 1.0f, 100.0f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } ) ); @@ -382,22 +409,26 @@ public void testMultipleRRFRetrievers() { ElasticsearchAssertions.assertResponse(req, resp -> { assertNull(resp.pointInTimeId()); assertNotNull(resp.getHits().getTotalHits()); - assertThat(resp.getHits().getTotalHits().value(), equalTo(6L)); + assertThat(resp.getHits().getTotalHits().value(), equalTo(5L)); assertThat(resp.getHits().getTotalHits().relation(), equalTo(TotalHits.Relation.EQUAL_TO)); assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_7")); + assertThat(resp.getHits().getAt(0).getScore(), equalTo(112f)); assertThat(resp.getHits().getAt(1).getId(), equalTo("doc_2")); - assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_6")); - assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_1")); - assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_3")); - assertThat(resp.getHits().getAt(5).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(1).getScore(), equalTo(38f)); + assertThat(resp.getHits().getAt(2).getId(), equalTo("doc_1")); + assertThat(resp.getHits().getAt(2).getScore(), equalTo(20f)); + assertThat(resp.getHits().getAt(3).getId(), equalTo("doc_6")); + assertThat(resp.getHits().getAt(3).getScore(), equalTo(19f)); + assertThat(resp.getHits().getAt(4).getId(), equalTo("doc_4")); + assertThat(resp.getHits().getAt(4).getScore(), equalTo(16f)); }); } - public void testRRFExplainWithNamedRetrievers() { + public void testLinearExplainWithNamedRetrievers() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) @@ -408,6 +439,7 @@ public void testRRFExplainWithNamedRetrievers() { ); standard0.retrieverName("my_custom_retriever"); // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) @@ -416,7 +448,15 @@ public void testRRFExplainWithNamedRetrievers() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 source.retriever( new LinearRetrieverBuilder( Arrays.asList( @@ -442,20 +482,39 @@ public void testRRFExplainWithNamedRetrievers() { assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); var rrfDetails = resp.getHits().getAt(0).getExplanation().getDetails()[0]; assertThat(rrfDetails.getDetails().length, equalTo(3)); - assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [2, 1, 1]")); - - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [2] in query at index [0]")); - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); - assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [1] in query at index [1]")); - assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [1] in query at index [2]")); + assertThat( + rrfDetails.getDescription(), + equalTo( + "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." + ) + ); + + assertThat( + rrfDetails.getDetails()[0].getDescription(), + containsString( + "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] using score normalizer [none] for original matching query with score" + ) + ); + assertThat( + rrfDetails.getDetails()[1].getDescription(), + containsString( + "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] for original matching query with score:" + ) + ); + assertThat( + rrfDetails.getDetails()[2].getDescription(), + containsString( + "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] for original matching query with score" + ) + ); }); } - public void testRRFExplainWithAnotherNestedRRF() { + public void testLinearExplainWithAnotherNestedLinear() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this one retrieves docs 1, 2, 4, 6, and 7 + // with scores 10, 9, 8, 7, 6 StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_1")).boost(10L)) @@ -466,6 +525,7 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard0.retrieverName("my_custom_retriever"); // this one retrieves docs 2 and 6 due to prefilter + // with scores 20, 5 StandardRetrieverBuilder standard1 = new StandardRetrieverBuilder( QueryBuilders.boolQuery() .should(QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_2")).boost(20L)) @@ -474,9 +534,16 @@ public void testRRFExplainWithAnotherNestedRRF() { ); standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD)); // this one retrieves docs 2, 3, 6, and 7 - KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[]{2.0f}, null, 10, 100, null, null); - - LinearRetrieverBuilder nestedRRF = new LinearRetrieverBuilder( + // with scores 1, 0.5, 0.05882353, 0.03846154 + KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null); + // final ranking with no-normalizer would be: doc 2, 6, 1, 4, 7, 3 + // doc 1: 10 + // doc 2: 9 + 20 + 1 = 30 + // doc 3: 0.5 + // doc 4: 8 + // doc 6: 7 + 5 + 0.05882353 = 12.05882353 + // doc 7: 6 + 0.03846154 = 6.03846154 + LinearRetrieverBuilder nestedLinear = new LinearRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource(standard0, null), new CompoundRetrieverBuilder.RetrieverSource(standard1, null), @@ -484,16 +551,20 @@ public void testRRFExplainWithAnotherNestedRRF() { ), rankWindowSize ); + nestedLinear.retrieverName("nested_linear"); + // this one retrieves docs 6 with a score of 100 StandardRetrieverBuilder standard2 = new StandardRetrieverBuilder( QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds("doc_6")).boost(20L) ); source.retriever( new LinearRetrieverBuilder( Arrays.asList( - new CompoundRetrieverBuilder.RetrieverSource(nestedRRF, null), + new CompoundRetrieverBuilder.RetrieverSource(nestedLinear, null), new CompoundRetrieverBuilder.RetrieverSource(standard2, null) ), - rankWindowSize + rankWindowSize, + new float[] { 1, 5f }, + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } ) ); source.explain(true); @@ -509,27 +580,31 @@ public void testRRFExplainWithAnotherNestedRRF() { assertThat(resp.getHits().getAt(0).getExplanation().isMatch(), equalTo(true)); assertThat(resp.getHits().getAt(0).getExplanation().getDescription(), containsString("sum of:")); assertThat(resp.getHits().getAt(0).getExplanation().getDetails().length, equalTo(2)); - var rrfTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0]; - assertThat(rrfTopLevel.getDetails().length, equalTo(2)); - assertThat(rrfTopLevel.getDescription(), containsString("computed for initial ranks [2, 1]")); - assertThat(rrfTopLevel.getDetails()[0].getDetails()[0].getDescription(), containsString("rrf score")); - assertThat(rrfTopLevel.getDetails()[1].getDetails()[0].getDescription(), containsString("ConstantScore")); - - var rrfDetails = rrfTopLevel.getDetails()[0].getDetails()[0]; - assertThat(rrfDetails.getDetails().length, equalTo(3)); - assertThat(rrfDetails.getDescription(), containsString("computed for initial ranks [4, 2, 3]")); - - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [4] in query at index [0]")); - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("for rank [4] in query at index [0]")); - assertThat(rrfDetails.getDetails()[0].getDescription(), containsString("[my_custom_retriever]")); - assertThat(rrfDetails.getDetails()[1].getDescription(), containsString("for rank [2] in query at index [1]")); - assertThat(rrfDetails.getDetails()[2].getDescription(), containsString("for rank [3] in query at index [2]")); + var linearTopLevel = resp.getHits().getAt(0).getExplanation().getDetails()[0]; + assertThat(linearTopLevel.getDetails().length, equalTo(2)); + assertThat( + linearTopLevel.getDescription(), + containsString( + "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query." + ) + ); + assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]")); + assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("nested_linear")); + assertThat(linearTopLevel.getDetails()[1].getDescription(), containsString("weighted score: [100.0]")); + + var linearNested = linearTopLevel.getDetails()[0]; + assertThat(linearNested.getDetails()[0].getDetails().length, equalTo(3)); + assertThat(linearNested.getDetails()[0].getDetails()[0].getDescription(), containsString("weighted score: [7.0]")); + assertThat(linearNested.getDetails()[0].getDetails()[1].getDescription(), containsString("weighted score: [5.0]")); + assertThat(linearNested.getDetails()[0].getDetails()[2].getDescription(), containsString("weighted score: [0.05882353]")); + + var standard0Details = linearTopLevel.getDetails()[1]; + assertThat(standard0Details.getDetails()[0].getDescription(), containsString("ConstantScore")); }); } - public void testRRFInnerRetrieverAll4xxSearchErrors() { + public void testLinearInnerRetrieverAll4xxSearchErrors() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this will throw a 4xx error during evaluation StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( @@ -557,7 +632,7 @@ public void testRRFInnerRetrieverAll4xxSearchErrors() { assertThat( ex.getMessage(), containsString( - "[rrf] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions." + "[linear] search failed - retrievers '[standard]' returned errors. All failures are attached as suppressed exceptions." ) ); assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.BAD_REQUEST)); @@ -565,9 +640,8 @@ public void testRRFInnerRetrieverAll4xxSearchErrors() { assertThat(ex.getSuppressed()[0].getCause().getCause(), instanceOf(IllegalArgumentException.class)); } - public void testRRFInnerRetrieverMultipleErrorsOne5xx() { + public void testLinearInnerRetrieverMultipleErrorsOne5xx() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this will throw a 4xx error during evaluation StandardRetrieverBuilder standard0 = new StandardRetrieverBuilder( @@ -595,7 +669,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder assertThat( ex.getMessage(), containsString( - "[rrf] search failed - retrievers '[standard, test]' returned errors. All failures are attached as suppressed exceptions." + "[linear] search failed - retrievers '[standard, test]' returned errors. All failures are attached as suppressed exceptions." ) ); assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); @@ -604,9 +678,8 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder assertThat(ex.getSuppressed()[1].getCause().getCause(), instanceOf(IllegalStateException.class)); } - public void testRRFInnerRetrieverErrorWhenExtractingToSource() { + public void testLinearInnerRetrieverErrorWhenExtractingToSource() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { @Override @@ -639,9 +712,8 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); } - public void testRRFInnerRetrieverErrorOnTopDocs() { + public void testLinearInnerRetrieverErrorOnTopDocs() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); TestRetrieverBuilder failingRetriever = new TestRetrieverBuilder("some value") { @Override @@ -675,9 +747,8 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); } - public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { + public void testLinearFiltersPropagatedToKnnQueryVectorBuilder() { final int rankWindowSize = 100; - final int rankConstant = 10; SearchSourceBuilder source = new SearchSourceBuilder(); // this will retriever all but 7 only due to top-level filter StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); @@ -685,7 +756,7 @@ public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( "vector", null, - new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[]{3}), + new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }), 10, 10, null, @@ -712,7 +783,7 @@ public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { } public void testRewriteOnce() { - final float[] vector = new float[]{1}; + final float[] vector = new float[] { 1 }; AtomicInteger numAsyncCalls = new AtomicInteger(); QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() { @Override diff --git a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java similarity index 90% rename from server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java index f6b3d09afbdcb..19eab1dc19e73 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/LinearRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java @@ -1,19 +1,18 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.rank; +package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.Explanation; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 89999e873c651..65c0c8c41ef38 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -12,7 +12,6 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; -import org.elasticsearch.search.rank.LinearRankDoc; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; @@ -92,10 +91,7 @@ public static LinearRetrieverBuilder fromXContent(XContentParser parser, Retriev return PARSER.apply(parser, context); } - LinearRetrieverBuilder( - List innerRetrievers, - int rankWindowSize - ) { + LinearRetrieverBuilder(List innerRetrievers, int rankWindowSize) { this(innerRetrievers, rankWindowSize, null, null); } @@ -106,6 +102,12 @@ public LinearRetrieverBuilder( ScoreNormalizer[] normalizers ) { super(innerRetrievers, rankWindowSize); + if (weights != null && weights.length != innerRetrievers.size()) { + throw new IllegalArgumentException("The number of weights must match the number of inner retrievers"); + } + if (normalizers != null && normalizers.length != innerRetrievers.size()) { + throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers"); + } if (weights == null) { this.weights = new float[innerRetrievers.size()]; Arrays.fill(this.weights, DEFAULT_WEIGHT); @@ -124,6 +126,7 @@ public LinearRetrieverBuilder( protected LinearRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { LinearRetrieverBuilder clone = new LinearRetrieverBuilder(newChildRetrievers, rankWindowSize, weights, normalizers); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; + clone.retrieverName = retrieverName; return clone; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java index 8d19337a0974d..251015b21ff50 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankPlugin.java @@ -17,6 +17,7 @@ import org.elasticsearch.search.rank.RankShardResult; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xpack.rank.linear.LinearRankDoc; import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder; import java.util.List; @@ -42,7 +43,8 @@ public List getNamedWriteables() { return List.of( new NamedWriteableRegistry.Entry(RankBuilder.class, NAME, RRFRankBuilder::new), new NamedWriteableRegistry.Entry(RankShardResult.class, NAME, RRFRankShardResult::new), - new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new) + new NamedWriteableRegistry.Entry(RankDoc.class, RRFRankDoc.NAME, RRFRankDoc::new), + new NamedWriteableRegistry.Entry(RankDoc.class, LinearRankDoc.NAME, LinearRankDoc::new) ); } @@ -55,6 +57,7 @@ public List getNamedXContent() { public List> getRetrievers() { return List.of( new RetrieverSpec<>(new ParseField(NAME), RRFRetrieverBuilder::fromXContent), - new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent)); + new RetrieverSpec<>(new ParseField(LinearRetrieverBuilder.NAME), LinearRetrieverBuilder::fromXContent) + ); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index a749a7c402c30..c6eb702e6fe84 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -101,6 +101,7 @@ public String getName() { protected RRFRetrieverBuilder clone(List newRetrievers, List newPreFilterQueryBuilders) { RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; + clone.retrieverName = retrieverName; return clone; } From 5b253aa498626d6f50d99ab14bf9c200749be9ba Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 13:05:12 +0200 Subject: [PATCH 35/57] iter --- .../search.retrievers/40_linear_retriever.yml | 1028 ----------------- 1 file changed, 1028 deletions(-) delete mode 100644 rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml deleted file mode 100644 index d82d2f4d41c4f..0000000000000 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.retrievers/40_linear_retriever.yml +++ /dev/null @@ -1,1028 +0,0 @@ -setup: - - requires: - cluster_features: [ "linear_retriever_supported" ] - reason: "Support for linear retriever" - test_runner_features: close_to - - - do: - indices.create: - index: test - body: - mappings: - properties: - vector: - type: dense_vector - dims: 1 - index: true - similarity: l2_norm - keyword: - type: keyword - other_keyword: - type: keyword - timestamp: - type: date - - - do: - bulk: - refresh: true - index: test - body: - - '{"index": {"_id": 1 }}' - - '{"vector": [1], "keyword": "one", "other_keyword": "other", "timestamp": "2021-01-01T00:00:00"}' - - '{"index": {"_id": 2 }}' - - '{"vector": [2], "keyword": "two", "timestamp": "2022-01-01T00:00:00"}' - - '{"index": {"_id": 3 }}' - - '{"vector": [3], "keyword": "three", "timestamp": "2023-01-01T00:00:00"}' - - '{"index": {"_id": 4 }}' - - '{"vector": [4], "keyword": "four", "other_keyword": "other", "timestamp": "2024-01-01T00:00:00"}' - ---- -"basic linear weighted combination of a standard and knn retrievers": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - } - } - }, - weight: 0.5 - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - - - match: { hits.total.value: 2 } - - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._score: 5.0 } - - match: { hits.hits.1._id: "4" } - - match: { hits.hits.1._score: 2.0 } - ---- -"should normalize initial scores": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 5.0 - } - } - ] - } - } - } - }, - weight: 10.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - - - match: { hits.total.value: 4 } - - match: { hits.hits.0._id: "1" } - - match: {hits.hits.0._score: 10.0} - - match: { hits.hits.1._id: "2" } - - match: {hits.hits.1._score: 8.0} - - match: { hits.hits.2._id: "4" } - - match: {hits.hits.2._score: 2.0} - - match: { hits.hits.2._score: 2.0 } - - match: { hits.hits.3._id: "3" } - - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } - ---- -"should normalize initial scores with a custom minmax normalizer": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 # normalized score for this would be -0.55 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 9.0 # normalized score for this would be -0.56 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 5.0 # normalized score for this would be -0.63 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "four" - } - } - }, - boost: 1.0 # normalized score for this would be -0.7 - } - } - ] - } - } - } - }, - normalizer: { - minmax: { - min: 42, - max: 100 - } - } - }, - { - # this only provides a score of 10 for doc 4 - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 10.0 - } - ] - - - match: { hits.total.value: 4 } - - match: { hits.hits.0._id: "4" } - - close_to: { hits.hits.0._score: { value: 50.2931, error: 0.001 } } - - match: { hits.hits.1._id: "1" } - - close_to: { hits.hits.1._score: { value: 40.4482, error: 0.001 } } - - match: { hits.hits.2._id: "2" } - - close_to: { hits.hits.2._score: { value: 40.4310, error: 0.001 } } - - match: { hits.hits.3._id: "3" } - - close_to: { hits.hits.3._score: { value: 40.3620, error: 0.001 } } - ---- -"should throw on unknown normalizer": - - do: - catch: /Unknown normalizer \[aardvark\]/ - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - } - } - }, - weight: 1.0, - normalizer: { - aardvark: { } - } - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - ---- -"pagination within a consistent rank_window_size": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 5.0 - } - } - ] - } - } - } - }, - weight: 10.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - from: 2 - size: 1 - - - match: { hits.total.value: 4 } - - length: { hits.hits: 1 } - - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._score: 2.0 } - - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 5.0 - } - } - ] - } - } - } - }, - weight: 10.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - from: 3 - size: 1 - - - match: { hits.total.value: 4 } - - match: { hits.hits.0._id: "3" } - - close_to: { hits.hits.0._score: { value: 0.0, error: 0.001 } } - ---- -"should throw when rank_window_size less than size": - - do: - catch: "/\\[linear\\] requires \\[rank_window_size: 2\\] be greater than or equal to \\[size: 10\\]/" - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - match_all: { } - } - } - }, - weight: 10.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - rank_window_size: 2 - size: 10 ---- -"should respect rank_window_size for normalization and returned hits": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 5.0 - } - } - ] - } - } - } - }, - weight: 1.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - rank_window_size: 2 - size: 2 - - - match: { hits.total.value: 4 } - - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._score: 2.0 } - - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._score: 1.0 } - ---- -"explain should provide info on weights and inner retrievers": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "four" - } - } - }, - boost: 1.0 - } - } - ] - } - }, - _name: "my_standard_retriever" - } - }, - weight: 10.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 20.0 - } - ] - explain: true - size: 2 - - - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._explanation.description: "/weighted.linear.combination.score:.\\[20.0].computed.for.normalized.scores.\\[.*,.1.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} - - match: { hits.hits.0._explanation.details.0.value: 0.0 } - - match: { hits.hits.0._explanation.details.0.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[0\\].\\[my_standard_retriever\\]/" } - - match: { hits.hits.0._explanation.details.1.value: 20.0 } - - match: { hits.hits.0._explanation.details.1.description: "/.*weighted.score.*using.score.normalizer.\\[none\\].*/" } - - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._explanation.description: "/weighted.linear.combination.score:.\\[10.0].computed.for.normalized.scores.\\[1.0,.0.0\\].and.weights.\\[10.0,.20.0\\].as.sum.of.\\(weight\\[i\\].*.score\\[i\\]\\).for.each.query./"} - - match: { hits.hits.1._explanation.details.0.value: 10.0 } - - match: { hits.hits.1._explanation.details.0.description: "/.*weighted.score.*\\[my_standard_retriever\\].*using.score.normalizer.\\[minmax\\].*/" } - - match: { hits.hits.1._explanation.details.1.value: 0.0 } - - match: { hits.hits.1._explanation.details.1.description: "/.*weighted.score.*result.not.found.in.query.at.index.\\[1\\]/" } - ---- -"collapsing results": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - } - } - }, - weight: 0.5 - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - collapse: - field: other_keyword - inner_hits: { - name: sub_hits, - sort: - { - keyword: { - order: desc - } - } - } - - match: { hits.hits.0._id: "1" } - - length: { hits.hits.0.inner_hits.sub_hits.hits.hits : 2 } - - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.0._id: "1" } - - match: { hits.hits.0.inner_hits.sub_hits.hits.hits.1._id: "4" } - ---- -"multiple nested linear retrievers": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - } - } - }, - weight: 0.5 - }, - { - retriever: { - linear: { - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 20.0 - } - } - } - } - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - } - } - ] - } - }, - weight: 2.0 - } - ] - - - match: { hits.total.value: 3 } - - match: { hits.hits.0._id: "2" } - - match: { hits.hits.0._score: 40.0 } - - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._score: 5.0 } - - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._score: 2.0 } - ---- -"linear retriever with filters": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - } - } - }, - weight: 0.5 - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - filter: - term: - keyword: "four" - - - - match: { hits.total.value: 1 } - - length: {hits.hits: 1} - - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._score: 2.0 } - ---- -"linear retriever with filters on nested retrievers": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - filter: { - term: { - keyword: "four" - } - } - } - }, - weight: 0.5 - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 - } - ] - - - match: { hits.total.value: 1 } - - length: {hits.hits: 1} - - match: { hits.hits.0._id: "4" } - - match: { hits.hits.0._score: 2.0 } - - ---- -"linear retriever with custom sort and score for nested retrievers": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ - { - retriever: { - standard: { - query: { - constant_score: { - filter: { - bool: { - should: [ - { - term: { - keyword: { - value: "one" # this will give doc 1 a normalized score of 10 - } - } - }, - { - term: { - keyword: { - value: "two" # this will give doc 2 a normalized score of 10 - } - } - } ] - } - }, - boost: 10.0 - } - }, - sort: { - timestamp: { - order: "asc" - } - } - } - }, - weight: 1.0, - normalizer: "minmax" - }, - { - # because we're sorting on timestamp and use a rank window size of 2, we will only get to see - # docs 3 and 2. - # their `scores` (which are the timestamps) are: - # doc 3: 1672531200000 (2023-01-01T00:00:00) - # doc 2: 1640995200000 (2022-01-01T00:00:00) - # and their normalized scores based on the provided conf - # will be: - # normalized(doc3) = 0.59989 - # normalized(doc2) = 0.40010 - retriever: { - standard: { - query: { - function_score: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" - } - } - }, - boost: 9.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 1.0 - } - } - ] - } - }, - functions: [ { - script_score: { - script: { - source: "doc['timestamp'].value.millis" - } - } - } ], - "boost_mode": "replace" - } - }, - sort: { - timestamp: { - order: "desc" - } - } - } - }, - weight: 1.0, - normalizer: { - minmax: { - min: 1577836800000, # 2020-01-01T00:00:00 - max: 1735689600000 # 2025-01-01T00:00:00 - } - } - } - ] - rank_window_size: 2 - size: 2 - - - match: { hits.total.value: 3 } - - length: {hits.hits: 2} - - match: { hits.hits.0._id: "2" } - - close_to: { hits.hits.0._score: { value: 10.4001, error: 0.001 } } - - match: { hits.hits.1._id: "1" } - - match: { hits.hits.1._score: 10 } From 1eca5fec0d7a29452be1c24089e549a108ad0525 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 13:06:22 +0200 Subject: [PATCH 36/57] add license test --- .../test/license/100_license.yml | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml index cd227eec4e227..ca1ccf679ff81 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml @@ -111,3 +111,40 @@ setup: - match: { status: 403 } - match: { error.type: security_exception } - match: { error.reason: "current license is non-compliant for [Reciprocal Rank Fusion (RRF)]" } + + +--- +"linear retriever invalid license": + + - do: + catch: forbidden + search: + index: test + body: + track_total_hits: false + fields: [ "text" ] + retriever: + linear: + retrievers: [ + { + knn: { + field: vector, + query_vector: [ 0.0 ], + k: 3, + num_candidates: 3 + } + }, + { + standard: { + query: { + term: { + text: term + } + } + } + } + ] + + - match: { status: 403 } + - match: { error.type: security_exception } + - match: { error.reason: "current license is non-compliant for [linear retriever]" } From 1f36e18bf2ced3c47ee44352ca2d7a68470296c6 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 13:22:34 +0200 Subject: [PATCH 37/57] checkstyle --- .../xpack/rank/linear/LinearRetrieverIT.java | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index 99514cdbdc337..26561be038305 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -485,26 +485,30 @@ public void testLinearExplainWithNamedRetrievers() { assertThat( rrfDetails.getDescription(), equalTo( - "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." + "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " + + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." ) ); assertThat( rrfDetails.getDetails()[0].getDescription(), containsString( - "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] using score normalizer [none] for original matching query with score" + "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] " + + "using score normalizer [none] for original matching query with score" ) ); assertThat( rrfDetails.getDetails()[1].getDescription(), containsString( - "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] for original matching query with score:" + "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] " + + "for original matching query with score:" ) ); assertThat( rrfDetails.getDetails()[2].getDescription(), containsString( - "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] for original matching query with score" + "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] " + + "for original matching query with score" ) ); }); @@ -585,7 +589,8 @@ public void testLinearExplainWithAnotherNestedLinear() { assertThat( linearTopLevel.getDescription(), containsString( - "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query." + "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] " + + "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query." ) ); assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]")); @@ -669,7 +674,8 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder assertThat( ex.getMessage(), containsString( - "[linear] search failed - retrievers '[standard, test]' returned errors. All failures are attached as suppressed exceptions." + "[linear] search failed - retrievers '[standard, test]' returned errors. " + + "All failures are attached as suppressed exceptions." ) ); assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); From 33bc324f5261e85e3dcf33513c632d1e535fbb3b Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 22 Jan 2025 11:30:13 +0000 Subject: [PATCH 38/57] [CI] Auto commit changes from spotless --- .../xpack/rank/linear/LinearRetrieverIT.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index 26561be038305..b42c84516ed6b 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -485,30 +485,30 @@ public void testLinearExplainWithNamedRetrievers() { assertThat( rrfDetails.getDescription(), equalTo( - "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " + - "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." + "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " + + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." ) ); assertThat( rrfDetails.getDetails()[0].getDescription(), containsString( - "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] " + - "using score normalizer [none] for original matching query with score" + "weighted score: [9.0] in query at index [0] [my_custom_retriever] computed as [1.0 * 9.0] " + + "using score normalizer [none] for original matching query with score" ) ); assertThat( rrfDetails.getDetails()[1].getDescription(), containsString( - "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] " + - "for original matching query with score:" + "weighted score: [20.0] in query at index [1] computed as [1.0 * 20.0] using score normalizer [none] " + + "for original matching query with score:" ) ); assertThat( rrfDetails.getDetails()[2].getDescription(), containsString( - "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] " + - "for original matching query with score" + "weighted score: [1.0] in query at index [2] computed as [1.0 * 1.0] using score normalizer [none] " + + "for original matching query with score" ) ); }); @@ -589,8 +589,8 @@ public void testLinearExplainWithAnotherNestedLinear() { assertThat( linearTopLevel.getDescription(), containsString( - "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] " + - "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query." + "weighted linear combination score: [112.05882] computed for normalized scores [12.058824, 20.0] " + + "and weights [1.0, 5.0] as sum of (weight[i] * score[i]) for each query." ) ); assertThat(linearTopLevel.getDetails()[0].getDescription(), containsString("weighted score: [12.058824]")); @@ -674,8 +674,8 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder assertThat( ex.getMessage(), containsString( - "[linear] search failed - retrievers '[standard, test]' returned errors. " + - "All failures are attached as suppressed exceptions." + "[linear] search failed - retrievers '[standard, test]' returned errors. " + + "All failures are attached as suppressed exceptions." ) ); assertThat(ExceptionsHelper.status(ex), equalTo(RestStatus.INTERNAL_SERVER_ERROR)); From 4d82e28bbe6820b5db1aeec68039f33fc1474181 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 13:59:31 +0200 Subject: [PATCH 39/57] moving tests --- .../xpack/rank/linear/LinearRankDocTests.java | 0 .../rank/linear/LinearRetrieverBuilderParsingTests.java | 7 +------ 2 files changed, 1 insertion(+), 6 deletions(-) rename x-pack/plugin/rank-rrf/src/{internalClusterTest => test}/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java (100%) rename x-pack/plugin/rank-rrf/src/{internalClusterTest => test}/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java (89%) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java similarity index 100% rename from x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java rename to x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java similarity index 89% rename from x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java rename to x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index f868d60ea20bf..ca198a188f896 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -9,12 +9,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; -import org.elasticsearch.search.retriever.MinMaxScoreNormalizer; -import org.elasticsearch.search.retriever.RetrieverBuilder; -import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.search.retriever.ScoreNormalizer; -import org.elasticsearch.search.retriever.TestRetrieverBuilder; +import org.elasticsearch.search.retriever.*; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; From cfcd84f981edd221eec8b949bb247f51fb6c3876 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 14:07:41 +0200 Subject: [PATCH 40/57] checkstyle --- .../rank/linear/LinearRetrieverBuilderParsingTests.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index ca198a188f896..f868d60ea20bf 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -9,7 +9,12 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.search.retriever.*; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.MinMaxScoreNormalizer; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.ScoreNormalizer; +import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; From 174e0d0d4f6aec6432cb8a0a88ee4f7382187a44 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Wed, 22 Jan 2025 15:26:48 +0200 Subject: [PATCH 41/57] adding missing writeables for tests --- .../xpack/rank/linear/LinearRankDocTests.java | 8 ++++++-- .../rank/linear/LinearRetrieverBuilderParsingTests.java | 8 ++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java index b7bef70ce3024..fb83687b260d1 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java @@ -10,9 +10,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; +import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; import java.io.IOException; -import java.util.Collections; import java.util.List; public class LinearRankDocTests extends AbstractRankDocWireSerializingTestCase { @@ -35,7 +35,11 @@ protected LinearRankDoc createTestRankDoc() { @Override protected List getAdditionalNamedWriteables() { - return Collections.emptyList(); + try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) { + return rrfRankPlugin.getNamedWriteables(); + } catch (IOException ex) { + throw new AssertionError("Failed to create RRFRankPlugin", ex); + } } @Override diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index f868d60ea20bf..08daaf2a2101f 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -82,6 +83,13 @@ protected NamedXContentRegistry xContentRegistry() { TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion() ) ); + entries.add( + new NamedXContentRegistry.Entry( + RetrieverBuilder.class, + new ParseField(LinearRetrieverBuilder.NAME), + (p, c) -> LinearRetrieverBuilder.PARSER.apply(p, (RetrieverParserContext) c) + ) + ); return new NamedXContentRegistry(entries); } From aeacd33a63a0c9fb57d0aa0d03f8af7eca1a3af5 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 23 Jan 2025 12:55:30 +0200 Subject: [PATCH 42/57] addressing PR comments --- .../retrievers-examples.asciidoc | 4 +- .../retriever/CompoundRetrieverBuilder.java | 1 - .../xpack/rank/linear/LinearRetrieverIT.java | 2 - .../rank/linear}/IdentityScoreNormalizer.java | 10 ++-- .../xpack/rank/linear/LinearRankDoc.java | 49 ++++++++++++------- .../rank/linear/LinearRetrieverBuilder.java | 32 ++++++------ .../rank/linear/LinearRetrieverComponent.java | 2 - .../rank/linear}/MinMaxScoreNormalizer.java | 14 +++--- .../xpack/rank/linear}/ScoreNormalizer.java | 10 ++-- .../xpack/rank/linear/LinearRankDocTests.java | 11 +++-- .../LinearRetrieverBuilderParsingTests.java | 2 - .../test/license/100_license.yml | 3 ++ 12 files changed, 73 insertions(+), 67 deletions(-) rename {server/src/main/java/org/elasticsearch/search/retriever => x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear}/IdentityScoreNormalizer.java (75%) rename {server/src/main/java/org/elasticsearch/search/retriever => x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear}/MinMaxScoreNormalizer.java (87%) rename {server/src/main/java/org/elasticsearch/search/retriever => x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear}/ScoreNormalizer.java (81%) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 2e90100c86bda..c195ca62bbf79 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -202,8 +202,8 @@ retrievers using a weighted sum of the original scores. Since, as above, the sco we can also specify a `normalizer` that would ensure that all scores for the top ranked documents of a retriever lie in a specific range. -To implement this, we define a `linear` retriever, and a set of `components` as the nested heterogeneous results sets -that we will combine. We will solve a problem similar to the above, by merging the results of a `standard` and a `knn` +To implement this, we define a `linear` retriever, and along with a set of retrievers that will generate the heterogeneous +results sets that we will combine. We will solve a problem similar to the above, by merging the results of a `standard` and a `knn` retriever. As the `standard` retriever's scores are based on BM25 and are not strictly bounded, we will also define a `minmax` normalizer to ensure that the scores lie in the [0, 1] range. We will apply the same normalizer to `knn` as well to ensure that we capture the importance of each document within the result set. diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 902a05b8e5c91..53560e129ca5d 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -282,7 +282,6 @@ public int doHashCode() { protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) .trackTotalHits(false) - .trackScores(true) .storedFields(new StoredFieldsContext(false)) .size(rankWindowSize); // apply the pre-filters downstream once diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index b42c84516ed6b..01ac8450f43d6 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -26,9 +26,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; -import org.elasticsearch.search.retriever.IdentityScoreNormalizer; import org.elasticsearch.search.retriever.KnnRetrieverBuilder; -import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; diff --git a/server/src/main/java/org/elasticsearch/search/retriever/IdentityScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java similarity index 75% rename from server/src/main/java/org/elasticsearch/search/retriever/IdentityScoreNormalizer.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java index 68f0507ff3397..7b1f70b821b14 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/IdentityScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java @@ -1,13 +1,11 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.retriever; +package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ConstructingObjectParser; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java index 19eab1dc19e73..e0fd23d9c5aa3 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java @@ -19,6 +19,10 @@ import java.util.Arrays; import java.util.Objects; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.DEFAULT_SCORE; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_NORMALIZER; +import static org.elasticsearch.xpack.rank.linear.LinearRetrieverComponent.DEFAULT_WEIGHT; + public class LinearRankDoc extends RankDoc { public static final String NAME = "linear_rank_doc"; @@ -31,36 +35,41 @@ public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, Stri super(doc, score, shardIndex); this.weights = weights; this.normalizers = normalizers; - this.normalizedScores = new float[normalizers.length]; } public LinearRankDoc(StreamInput in) throws IOException { super(in); - weights = in.readFloatArray(); - normalizedScores = in.readFloatArray(); - normalizers = in.readStringArray(); + weights = in.readOptionalFloatArray(); + normalizedScores = in.readOptionalFloatArray(); + normalizers = in.readOptionalStringArray(); } @Override public Explanation explain(Explanation[] sources, String[] queryNames) { + assert normalizedScores != null; + assert normalizedScores.length == sources.length; + Explanation[] details = new Explanation[sources.length]; for (int i = 0; i < sources.length; i++) { final String queryAlias = queryNames[i] == null ? "" : " [" + queryNames[i] + "]"; final String queryIdentifier = "at index [" + i + "]" + queryAlias; - if (normalizedScores[i] > 0) { + final float weight = weights == null ? DEFAULT_WEIGHT : weights[i]; + final float normalizedScore = normalizedScores == null ? DEFAULT_SCORE : normalizedScores[i]; + final String normalizer = normalizers == null ? DEFAULT_NORMALIZER.getName() : normalizers[i]; + if (normalizedScore > 0) { details[i] = Explanation.match( - weights[i] * normalizedScores[i], + weight * normalizedScore, "weighted score: [" - + weights[i] * normalizedScores[i] + + weight * normalizedScore + "] in query " + queryIdentifier + " computed as [" - + weights[i] + + weight + " * " - + normalizedScores[i] + + normalizedScore + "]" + " using score normalizer [" - + normalizers[i] + + normalizer + "]" + " for original matching query with score:", sources[i] @@ -77,7 +86,7 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { + "] computed for normalized scores " + Arrays.toString(normalizedScores) + " and weights " - + Arrays.toString(weights) + + Arrays.toString(weights == null ? new float[sources.length] : weights) + " as sum of (weight[i] * score[i]) for each query.", details ); @@ -85,16 +94,22 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeFloatArray(weights); - out.writeFloatArray(normalizedScores); - out.writeStringArray(normalizers); + out.writeOptionalFloatArray(weights); + out.writeOptionalFloatArray(normalizedScores); + out.writeOptionalStringArray(normalizers); } @Override protected void doToXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("weights", weights); - builder.field("normalizedScores", normalizedScores); - builder.field("normalizers", normalizers); + if (weights != null) { + builder.field("weights", weights); + } + if (normalizedScores != null) { + builder.field("normalizedScores", normalizedScores); + } + if (normalizers != null) { + builder.field("normalizers", normalizers); + } } @Override diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 65c0c8c41ef38..918063c583deb 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -15,10 +15,8 @@ import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; -import org.elasticsearch.search.retriever.IdentityScoreNormalizer; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -51,6 +49,8 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder newChildRetrievers, @Override protected RankDoc[] combineInnerRetrieverResults(List rankResults) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); - final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); + final String[] normalizerNames = normalizers == null + ? null + : Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { + final ScoreNormalizer normalizer = normalizers == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; ScoreDoc[] originalScoreDocs = rankResults.get(result); - ScoreDoc[] normalizedScoreDocs = normalizers[result].normalizeScores(originalScoreDocs); + ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { int finalResult = result; int finalScoreIndex = scoreDocIndex; @@ -151,9 +144,14 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { weights, normalizerNames ); + value.normalizedScores = new float[rankResults.size()]; } + final float docScore = false == Float.isNaN(normalizedScoreDocs[finalScoreIndex].score) + ? normalizedScoreDocs[finalScoreIndex].score + : DEFAULT_SCORE; + final float weight = weights == null ? DEFAULT_WEIGHT : weights[finalResult]; value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; - value.score += weights[finalResult] * normalizedScoreDocs[finalScoreIndex].score; + value.score += weight * docScore; return value; } ); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index 168c94ddda6c9..7e3c02ccca115 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -8,10 +8,8 @@ package org.elasticsearch.xpack.rank.linear; import org.elasticsearch.common.ParsingException; -import org.elasticsearch.search.retriever.IdentityScoreNormalizer; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java similarity index 87% rename from server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java index bfd0ba35cbca0..3f0cffab9a0a5 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/MinMaxScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -1,13 +1,11 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.retriever; +package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ConstructingObjectParser; @@ -26,6 +24,8 @@ public class MinMaxScoreNormalizer extends ScoreNormalizer { public static final ParseField MIN_FIELD = new ParseField("min"); public static final ParseField MAX_FIELD = new ParseField("max"); + private static final float EPSILON = 1e-6f; + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { Float min = (Float) args[0]; Float max = (Float) args[1]; @@ -90,7 +90,7 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { if (min > max) { throw new IllegalArgumentException("[min=" + min + "] must be less than [max=" + max + "]"); } - boolean minEqualsMax = min.equals(max); + boolean minEqualsMax = Math.abs(min - max) < EPSILON; for (int i = 0; i < docs.length; i++) { float score; if (minEqualsMax) { diff --git a/server/src/main/java/org/elasticsearch/search/retriever/ScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java similarity index 81% rename from server/src/main/java/org/elasticsearch/search/retriever/ScoreNormalizer.java rename to x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java index ed1b0f21b785c..137b3aa2b8446 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/ScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java @@ -1,13 +1,11 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the "Elastic License - * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side - * Public License v 1"; you may not use this file except in compliance with, at - * your election, the "Elastic License 2.0", the "GNU Affero General Public - * License v3.0 only", or the "Server Side Public License, v 1". + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. */ -package org.elasticsearch.search.retriever; +package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ToXContent; diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java index fb83687b260d1..051aa6bddb4d7 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRankDocTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin; import java.io.IOException; @@ -59,16 +60,16 @@ protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOExceptio mutated.normalizedScores = instance.normalizedScores; mutated.rank = instance.rank; if (frequently()) { - mutated.doc = randomNonNegativeInt(); + mutated.doc = randomValueOtherThan(instance.doc, ESTestCase::randomNonNegativeInt); } if (frequently()) { - mutated.score = randomFloat(); + mutated.score = randomValueOtherThan(instance.score, ESTestCase::randomFloat); } if (frequently()) { - mutated.shardIndex = randomNonNegativeInt(); + mutated.shardIndex = randomValueOtherThan(instance.shardIndex, ESTestCase::randomNonNegativeInt); } if (frequently()) { - mutated.rank = randomNonNegativeInt(); + mutated.rank = randomValueOtherThan(instance.rank, ESTestCase::randomNonNegativeInt); } if (frequently()) { for (int i = 0; i < mutated.normalizedScores.length; i++) { @@ -87,7 +88,7 @@ protected LinearRankDoc mutateInstance(LinearRankDoc instance) throws IOExceptio if (frequently()) { for (int i = 0; i < mutated.normalizers.length; i++) { if (frequently()) { - mutated.normalizers[i] = randomAlphaOfLengthBetween(1, 10); + mutated.normalizers[i] = randomValueOtherThan(instance.normalizers[i], () -> randomAlphaOfLengthBetween(1, 10)); } } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index 08daaf2a2101f..fc2643d3a5c79 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -10,10 +10,8 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; -import org.elasticsearch.search.retriever.MinMaxScoreNormalizer; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; -import org.elasticsearch.search.retriever.ScoreNormalizer; import org.elasticsearch.search.retriever.TestRetrieverBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml index ca1ccf679ff81..42d0fa1998246 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/license/100_license.yml @@ -115,6 +115,9 @@ setup: --- "linear retriever invalid license": + - requires: + cluster_features: [ "linear_retriever_supported" ] + reason: "Support for linear retriever" - do: catch: forbidden From 58e2887b8349586e1a83bf020ce39c44e0132051 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 23 Jan 2025 13:38:54 +0200 Subject: [PATCH 43/57] fixing tests after refactoring --- .../xpack/rank/linear/LinearRetrieverIT.java | 2 +- .../elasticsearch/xpack/rank/linear/LinearRankDoc.java | 3 +-- .../xpack/rank/linear/LinearRetrieverBuilder.java | 9 +++++++++ .../xpack/rank/linear/MinMaxScoreNormalizer.java | 9 +++++++++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index 01ac8450f43d6..6b452f6cd3f0c 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -484,7 +484,7 @@ public void testLinearExplainWithNamedRetrievers() { rrfDetails.getDescription(), equalTo( "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " - + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." + + "as sum of (weight[i] * score[i]) for each query." ) ); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java index e0fd23d9c5aa3..a1b8943517f93 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java @@ -85,8 +85,7 @@ public Explanation explain(Explanation[] sources, String[] queryNames) { + score + "] computed for normalized scores " + Arrays.toString(normalizedScores) - + " and weights " - + Arrays.toString(weights == null ? new float[sources.length] : weights) + + (weights == null ? "" : " and weights " + Arrays.toString(weights)) + " as sum of (weight[i] * score[i]) for each query.", details ); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 918063c583deb..e46cccd9ff0d5 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; @@ -120,6 +121,12 @@ protected LinearRetrieverBuilder clone(List newChildRetrievers, return clone; } + @Override + protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) { + sourceBuilder.trackScores(true); + return sourceBuilder; + } + @Override protected RankDoc[] combineInnerRetrieverResults(List rankResults) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); @@ -146,6 +153,8 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { ); value.normalizedScores = new float[rankResults.size()]; } + // if we do not have scores associated with this result set, just ignore its contribution to the final + // score computation by setting its score to 0. final float docScore = false == Float.isNaN(normalizedScoreDocs[finalScoreIndex].score) ? normalizedScoreDocs[finalScoreIndex].score : DEFAULT_SCORE; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java index 3f0cffab9a0a5..8f5865dcd9fea 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -68,7 +68,11 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { float correction = 0f; float xMin = Float.MAX_VALUE; float xMax = Float.MIN_VALUE; + boolean atLeastOneValidScore = false; for (ScoreDoc rd : docs) { + if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) { + atLeastOneValidScore = true; + } if (rd.score > xMax) { xMax = rd.score; } @@ -76,6 +80,11 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { xMin = rd.score; } } + if (false == atLeastOneValidScore) { + // we do not have any scores to normalize, so we just return the original array + return docs; + } + if (min == null) { min = xMin; } else { From 29438eee00d2fd8ddb432bcf95e28b9859ee068a Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 23 Jan 2025 13:41:56 +0200 Subject: [PATCH 44/57] iter --- .../xpack/rank/linear/LinearRetrieverBuilder.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index e46cccd9ff0d5..9ec28838ac5b3 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -134,7 +134,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { ? null : Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { - final ScoreNormalizer normalizer = normalizers == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; + final ScoreNormalizer normalizer = normalizers == null || normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; ScoreDoc[] originalScoreDocs = rankResults.get(result); ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { @@ -158,7 +158,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { final float docScore = false == Float.isNaN(normalizedScoreDocs[finalScoreIndex].score) ? normalizedScoreDocs[finalScoreIndex].score : DEFAULT_SCORE; - final float weight = weights == null ? DEFAULT_WEIGHT : weights[finalResult]; + final float weight = weights == null || Float.isNaN(weights[finalResult]) ? DEFAULT_WEIGHT : weights[finalResult]; value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; value.score += weight * docScore; return value; From 7d3f36c9d5abde1889baf1d2dacb220736ee1fa7 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 23 Jan 2025 11:48:27 +0000 Subject: [PATCH 45/57] [CI] Auto commit changes from spotless --- .../xpack/rank/linear/LinearRetrieverBuilder.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 9ec28838ac5b3..19e0f6b605d37 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -134,7 +134,9 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { ? null : Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { - final ScoreNormalizer normalizer = normalizers == null || normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; + final ScoreNormalizer normalizer = normalizers == null || normalizers[result] == null + ? IdentityScoreNormalizer.INSTANCE + : normalizers[result]; ScoreDoc[] originalScoreDocs = rankResults.get(result); ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { From 8677263aea34a2b21f14d1a7467f1c532298b029 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 23 Jan 2025 14:36:27 +0200 Subject: [PATCH 46/57] iter --- .../xpack/rank/linear/LinearRetrieverIT.java | 2 +- .../rank/linear/LinearRetrieverBuilder.java | 26 ++++-- .../rank/linear/LinearRetrieverComponent.java | 82 ++++++++----------- 3 files changed, 54 insertions(+), 56 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index 6b452f6cd3f0c..2359a4dbc5758 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -179,7 +179,7 @@ public void testLinearRetrieverWithAggs() { // all requests would have an equal weight and use the identity normalizer source.retriever( - new org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder( + new LinearRetrieverBuilder( Arrays.asList( new CompoundRetrieverBuilder.RetrieverSource(standard0, null), new CompoundRetrieverBuilder.RetrieverSource(standard1, null), diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 9ec28838ac5b3..308dc7d3881df 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -82,6 +82,18 @@ public final class LinearRetrieverBuilder extends CompoundRetrieverBuilder innerRetrievers, int rankWindowSize) { - this(innerRetrievers, rankWindowSize, null, null); + this(innerRetrievers, rankWindowSize, getDefaultWeight(innerRetrievers.size()), getDefaultNormalizers(innerRetrievers.size())); } public LinearRetrieverBuilder( @@ -103,10 +115,10 @@ public LinearRetrieverBuilder( ScoreNormalizer[] normalizers ) { super(innerRetrievers, rankWindowSize); - if (weights != null && weights.length != innerRetrievers.size()) { + if (weights.length != innerRetrievers.size()) { throw new IllegalArgumentException("The number of weights must match the number of inner retrievers"); } - if (normalizers != null && normalizers.length != innerRetrievers.size()) { + if (normalizers.length != innerRetrievers.size()) { throw new IllegalArgumentException("The number of normalizers must match the number of inner retrievers"); } this.weights = weights; @@ -130,11 +142,9 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu @Override protected RankDoc[] combineInnerRetrieverResults(List rankResults) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); - final String[] normalizerNames = normalizers == null - ? null - : Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); + final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { - final ScoreNormalizer normalizer = normalizers == null || normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; + final ScoreNormalizer normalizer = normalizers[result] == null ? IdentityScoreNormalizer.INSTANCE : normalizers[result]; ScoreDoc[] originalScoreDocs = rankResults.get(result); ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { @@ -158,7 +168,7 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { final float docScore = false == Float.isNaN(normalizedScoreDocs[finalScoreIndex].score) ? normalizedScoreDocs[finalScoreIndex].score : DEFAULT_SCORE; - final float weight = weights == null || Float.isNaN(weights[finalResult]) ? DEFAULT_WEIGHT : weights[finalResult]; + final float weight = Float.isNaN(weights[finalResult]) ? DEFAULT_WEIGHT : weights[finalResult]; value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; value.score += weight * docScore; return value; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index 7e3c02ccca115..23035b98f5a86 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -10,6 +10,8 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -17,7 +19,8 @@ import java.io.IOException; -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.RETRIEVERS_FIELD; public class LinearRetrieverComponent implements ToXContentObject { @@ -46,52 +49,37 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { - RetrieverBuilder retrieverBuilder = null; - Float weight = null; - ScoreNormalizer normalizer = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { - if (RETRIEVER_FIELD.match(parser.currentName(), parser.getDeprecationHandler())) { - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - parser.nextToken(); - ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.currentToken(), parser); - final String retrieverName = parser.currentName(); - parser.nextToken(); - retrieverBuilder = parser.namedObject(RetrieverBuilder.class, retrieverName, context); - parser.nextToken(); - } else if (WEIGHT_FIELD.match(parser.currentName(), parser.getDeprecationHandler())) { - parser.nextToken(); - weight = parser.floatValue(); - } else if (NORMALIZER_FIELD.match(parser.currentName(), parser.getDeprecationHandler())) { - parser.nextToken(); - if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - normalizer = ScoreNormalizer.valueOf(parser.text()); - } else if (parser.currentToken() == XContentParser.Token.START_OBJECT) { - parser.nextToken(); - normalizer = ScoreNormalizer.parse(parser.currentName(), parser); - parser.nextToken(); - } else { - throw new ParsingException(parser.getTokenLocation(), "Unsupported token [" + parser.currentToken() + "]"); - } - } else { - throw new ParsingException( - parser.getTokenLocation(), - "Unexpected token [" + parser.currentToken() + "] for linear retriever." - ); - } - } else { - throw new ParsingException( - parser.getTokenLocation(), - "Expected [" + XContentParser.Token.FIELD_NAME + "] but got [" + parser.currentToken() + "] instead." - ); - } + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "retriever-component", + false, + args -> { + RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0]; + Float weight = (Float) args[1]; + ScoreNormalizer normalizer = (ScoreNormalizer) args[2]; + return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer); } - if (retrieverBuilder == null) { - throw new IllegalArgumentException("Missing required field [" + RETRIEVER_FIELD.getPreferredName() + "]"); - } - return new LinearRetrieverComponent(retrieverBuilder, weight, normalizer); + ); + + static { + PARSER.declareNamedObject(constructorArg(), (p, c, n) -> { + RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c); + c.trackRetrieverUsage(innerRetriever.getName()); + return innerRetriever; + }, RETRIEVER_FIELD); + PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); + PARSER.declareField(optionalConstructorArg(), (p, c) -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ScoreNormalizer.valueOf(p.text()); + } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { + p.nextToken(); + return ScoreNormalizer.parse(p.currentName(), p); + } + throw new ParsingException(p.getTokenLocation(), "Unsupported token [" + p.currentToken() + "]"); + }, NORMALIZER_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); + } + + public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + return PARSER.apply(parser, context); } } From d961f2228b39deb945168e1dd538519f3d2a6c74 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Thu, 23 Jan 2025 15:44:48 +0200 Subject: [PATCH 47/57] updating parsing to use a static parser --- .../xpack/rank/linear/LinearRetrieverIT.java | 4 ++-- .../xpack/rank/linear/LinearRetrieverComponent.java | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java index 2359a4dbc5758..f98231a647470 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java @@ -387,7 +387,7 @@ public void testMultipleLinearRetrievers() { ), rankWindowSize, new float[] { 2.0f, 1.0f }, - null + new ScoreNormalizer[] { IdentityScoreNormalizer.INSTANCE, IdentityScoreNormalizer.INSTANCE } ), null ), @@ -484,7 +484,7 @@ public void testLinearExplainWithNamedRetrievers() { rrfDetails.getDescription(), equalTo( "weighted linear combination score: [30.0] computed for normalized scores [9.0, 20.0, 1.0] " - + "as sum of (weight[i] * score[i]) for each query." + + "and weights [1.0, 1.0, 1.0] as sum of (weight[i] * score[i]) for each query." ) ); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index 23035b98f5a86..f0b6dd519afe3 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -21,7 +21,6 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; -import static org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder.RETRIEVERS_FIELD; public class LinearRetrieverComponent implements ToXContentObject { @@ -45,7 +44,9 @@ public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field(RETRIEVERS_FIELD.getPreferredName(), retriever); + builder.field(RETRIEVER_FIELD.getPreferredName(), retriever); + builder.field(WEIGHT_FIELD.getPreferredName(), weight); + builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer); return builder; } @@ -73,7 +74,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return ScoreNormalizer.valueOf(p.text()); } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { p.nextToken(); - return ScoreNormalizer.parse(p.currentName(), p); + ScoreNormalizer normalizer = ScoreNormalizer.parse(p.currentName(), p); + p.nextToken(); + return normalizer; } throw new ParsingException(p.getTokenLocation(), "Unsupported token [" + p.currentToken() + "]"); }, NORMALIZER_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); From 3640ae14428f7332649ef5a48c3612183db63944 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Fri, 24 Jan 2025 08:35:56 +0200 Subject: [PATCH 48/57] avoid populating LinearRankDoc metadata if not explain --- .../xpack/rank/linear/LinearRankDoc.java | 8 ++++- .../rank/linear/LinearRetrieverBuilder.java | 30 ++++++++++++------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java index a1b8943517f93..bb1c420bbd06c 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRankDoc.java @@ -31,6 +31,12 @@ public class LinearRankDoc extends RankDoc { final String[] normalizers; public float[] normalizedScores; + public LinearRankDoc(int doc, float score, int shardIndex) { + super(doc, score, shardIndex); + this.weights = null; + this.normalizers = null; + } + public LinearRankDoc(int doc, float score, int shardIndex, float[] weights, String[] normalizers) { super(doc, score, shardIndex); this.weights = weights; @@ -46,7 +52,7 @@ public LinearRankDoc(StreamInput in) throws IOException { @Override public Explanation explain(Explanation[] sources, String[] queryNames) { - assert normalizedScores != null; + assert normalizedScores != null && weights != null && normalizers != null; assert normalizedScores.length == sources.length; Explanation[] details = new Explanation[sources.length]; diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 308dc7d3881df..6068b6f118882 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -140,7 +140,7 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu } @Override - protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + protected RankDoc[] combineInnerRetrieverResults(List rankResults, boolean isExplain) { Map docsToRankResults = Maps.newMapWithExpectedSize(rankWindowSize); final String[] normalizerNames = Arrays.stream(normalizers).map(ScoreNormalizer::getName).toArray(String[]::new); for (int result = 0; result < rankResults.size(); result++) { @@ -154,14 +154,25 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex), (key, value) -> { if (value == null) { - value = new LinearRankDoc( - originalScoreDocs[finalScoreIndex].doc, - 0f, - originalScoreDocs[finalScoreIndex].shardIndex, - weights, - normalizerNames - ); - value.normalizedScores = new float[rankResults.size()]; + if (isExplain) { + value = new LinearRankDoc( + originalScoreDocs[finalScoreIndex].doc, + 0f, + originalScoreDocs[finalScoreIndex].shardIndex, + weights, + normalizerNames + ); + value.normalizedScores = new float[rankResults.size()]; + } else { + value = new LinearRankDoc( + originalScoreDocs[finalScoreIndex].doc, + 0f, + originalScoreDocs[finalScoreIndex].shardIndex + ); + } + } + if (isExplain) { + value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; } // if we do not have scores associated with this result set, just ignore its contribution to the final // score computation by setting its score to 0. @@ -169,7 +180,6 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { ? normalizedScoreDocs[finalScoreIndex].score : DEFAULT_SCORE; final float weight = Float.isNaN(weights[finalResult]) ? DEFAULT_WEIGHT : weights[finalResult]; - value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; value.score += weight * docScore; return value; } From 92591592295742eec9878fb917ec2fdb203741a4 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 27 Jan 2025 19:06:47 +0200 Subject: [PATCH 49/57] addressing PR comments - simplifying linear score computation --- .../rank/linear/LinearRetrieverBuilder.java | 48 +++++++------------ 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 6068b6f118882..15d790bc56242 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -148,42 +148,28 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults, b ScoreDoc[] originalScoreDocs = rankResults.get(result); ScoreDoc[] normalizedScoreDocs = normalizer.normalizeScores(originalScoreDocs); for (int scoreDocIndex = 0; scoreDocIndex < normalizedScoreDocs.length; scoreDocIndex++) { - int finalResult = result; - int finalScoreIndex = scoreDocIndex; - docsToRankResults.compute( + LinearRankDoc rankDoc = docsToRankResults.computeIfAbsent( new RankDoc.RankKey(originalScoreDocs[scoreDocIndex].doc, originalScoreDocs[scoreDocIndex].shardIndex), - (key, value) -> { - if (value == null) { - if (isExplain) { - value = new LinearRankDoc( - originalScoreDocs[finalScoreIndex].doc, - 0f, - originalScoreDocs[finalScoreIndex].shardIndex, - weights, - normalizerNames - ); - value.normalizedScores = new float[rankResults.size()]; - } else { - value = new LinearRankDoc( - originalScoreDocs[finalScoreIndex].doc, - 0f, - originalScoreDocs[finalScoreIndex].shardIndex - ); - } - } + key -> { if (isExplain) { - value.normalizedScores[finalResult] = normalizedScoreDocs[finalScoreIndex].score; + LinearRankDoc doc = new LinearRankDoc(key.doc(), 0f, key.shardIndex(), weights, normalizerNames); + doc.normalizedScores = new float[rankResults.size()]; + return doc; + } else { + return new LinearRankDoc(key.doc(), 0f, key.shardIndex()); } - // if we do not have scores associated with this result set, just ignore its contribution to the final - // score computation by setting its score to 0. - final float docScore = false == Float.isNaN(normalizedScoreDocs[finalScoreIndex].score) - ? normalizedScoreDocs[finalScoreIndex].score - : DEFAULT_SCORE; - final float weight = Float.isNaN(weights[finalResult]) ? DEFAULT_WEIGHT : weights[finalResult]; - value.score += weight * docScore; - return value; } ); + if (isExplain) { + rankDoc.normalizedScores[result] = normalizedScoreDocs[scoreDocIndex].score; + } + // if we do not have scores associated with this result set, just ignore its contribution to the final + // score computation by setting its score to 0. + final float docScore = false == Float.isNaN(normalizedScoreDocs[scoreDocIndex].score) + ? normalizedScoreDocs[scoreDocIndex].score + : DEFAULT_SCORE; + final float weight = Float.isNaN(weights[result]) ? DEFAULT_WEIGHT : weights[result]; + rankDoc.score += weight * docScore; } } // sort the results based on the final score, tiebreaker based on smaller doc id From ea1787fe5aa015cc466c2c1bcf9b893d5a870e77 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 27 Jan 2025 19:12:52 +0200 Subject: [PATCH 50/57] addressing PR comments - adding yaml test for linear retriever with interleaved results --- .../test/linear/10_linear_retriever.yml | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index d82d2f4d41c4f..9ccd6b1888ca4 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -83,6 +83,101 @@ setup: - match: { hits.hits.1._id: "4" } - match: { hits.hits.1._score: 2.0 } +--- +"basic linear weighted combination - interleaved results": + - do: + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + # this one will return docs 1 and doc 2 with scores 20 and 10 respectively + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "two" + } + } + }, + boost: 5.0 + } + } + ] + } + } + }, + weight: 2 + }, + { + # this one will return docs 3 and doc 4 with scores 15 and 12 respectively + retriever: { + standard: { + query: { + bool: { + should: [ + { + constant_score: { + filter: { + term: { + keyword: { + value: "three" + } + } + }, + boost: 5.0 + } + }, + { + constant_score: { + filter: { + term: { + keyword: { + value: "four" + } + } + }, + boost: 4.0 + } + } + ] + } + } + }, + weight: 3 + } + ] + + - match: { hits.total.value: 4 } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.0._score: 20.0 } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1._score: 15.0 } + - match: { hits.hits.2._id: "4" } + - match: { hits.hits.2._score: 12.0 } + - match: { hits.hits.3._id: "2" } + - match: { hits.hits.3._score: 10.0 } + --- "should normalize initial scores": - do: From ce8f60f5c27acee92daa02910a7445d0cfd35be0 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 27 Jan 2025 20:52:04 +0200 Subject: [PATCH 51/57] removing custom min max options for normalizer --- .../rank/linear/LinearRetrieverComponent.java | 18 +- .../rank/linear/MinMaxScoreNormalizer.java | 76 ++------ .../xpack/rank/linear/ScoreNormalizer.java | 10 -- .../LinearRetrieverBuilderParsingTests.java | 8 +- .../test/linear/10_linear_retriever.yml | 170 ++++-------------- 5 files changed, 59 insertions(+), 223 deletions(-) diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index f0b6dd519afe3..e78e1c71eaa62 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.rank.linear; -import org.elasticsearch.common.ParsingException; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.xcontent.ConstructingObjectParser; @@ -69,17 +68,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return innerRetriever; }, RETRIEVER_FIELD); PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD); - PARSER.declareField(optionalConstructorArg(), (p, c) -> { - if (p.currentToken() == XContentParser.Token.VALUE_STRING) { - return ScoreNormalizer.valueOf(p.text()); - } else if (p.currentToken() == XContentParser.Token.START_OBJECT) { - p.nextToken(); - ScoreNormalizer normalizer = ScoreNormalizer.parse(p.currentName(), p); - p.nextToken(); - return normalizer; - } - throw new ParsingException(p.getTokenLocation(), "Unsupported token [" + p.currentToken() + "]"); - }, NORMALIZER_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> ScoreNormalizer.valueOf(p.text()), + NORMALIZER_FIELD, + ObjectParser.ValueType.STRING + ); } public static LinearRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java index 8f5865dcd9fea..5726caeb58864 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -9,49 +9,21 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; - -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; - public class MinMaxScoreNormalizer extends ScoreNormalizer { public static final String NAME = "minmax"; - public static final ParseField MIN_FIELD = new ParseField("min"); - public static final ParseField MAX_FIELD = new ParseField("max"); - private static final float EPSILON = 1e-6f; - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - Float min = (Float) args[0]; - Float max = (Float) args[1]; - return new MinMaxScoreNormalizer(min, max); - }); - - static { - PARSER.declareFloat(optionalConstructorArg(), MIN_FIELD); - PARSER.declareFloat(optionalConstructorArg(), MAX_FIELD); - } + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + args -> new MinMaxScoreNormalizer() + ); - private Float min; - private Float max; - - public MinMaxScoreNormalizer() { - this.min = null; - this.max = null; - } - - public MinMaxScoreNormalizer(Float min, Float max) { - if (min != null && max != null && min >= max) { - throw new IllegalArgumentException("[min] must be less than [max]"); - } - this.min = min; - this.max = max; - } + public MinMaxScoreNormalizer() {} @Override public String getName() { @@ -65,19 +37,18 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { } // create a new array to avoid changing ScoreDocs in place ScoreDoc[] scoreDocs = new ScoreDoc[docs.length]; - float correction = 0f; - float xMin = Float.MAX_VALUE; - float xMax = Float.MIN_VALUE; + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; boolean atLeastOneValidScore = false; for (ScoreDoc rd : docs) { if (false == atLeastOneValidScore && false == Float.isNaN(rd.score)) { atLeastOneValidScore = true; } - if (rd.score > xMax) { - xMax = rd.score; + if (rd.score > max) { + max = rd.score; } - if (rd.score < xMin) { - xMin = rd.score; + if (rd.score < min) { + min = rd.score; } } if (false == atLeastOneValidScore) { @@ -85,27 +56,13 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { return docs; } - if (min == null) { - min = xMin; - } else { - if (min > xMin) { - correction = min - xMin; - } - } - if (max == null) { - max = xMax; - } - - if (min > max) { - throw new IllegalArgumentException("[min=" + min + "] must be less than [max=" + max + "]"); - } boolean minEqualsMax = Math.abs(min - max) < EPSILON; for (int i = 0; i < docs.length; i++) { float score; if (minEqualsMax) { score = min; } else { - score = correction + (docs[i].score - min) / (max - min); + score = (docs[i].score - min) / (max - min); } scoreDocs[i] = new ScoreDoc(docs[i].doc, score, docs[i].shardIndex); } @@ -117,12 +74,7 @@ public static MinMaxScoreNormalizer fromXContent(XContentParser parser) { } @Override - public void doToXContent(XContentBuilder builder, Params params) throws IOException { - if (min != null) { - builder.field(MIN_FIELD.getPreferredName(), min); - } - if (max != null) { - builder.field(MAX_FIELD.getPreferredName(), max); - } + public void doToXContent(XContentBuilder builder, Params params) { + // no-op } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java index 137b3aa2b8446..c6cfb7c1413b2 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java @@ -10,7 +10,6 @@ import org.apache.lucene.search.ScoreDoc; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; @@ -30,15 +29,6 @@ public static ScoreNormalizer valueOf(String normalizer) { } } - public static ScoreNormalizer parse(String normalizer, XContentParser p) { - if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { - return MinMaxScoreNormalizer.fromXContent(p); - } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { - return IdentityScoreNormalizer.fromXContent(p); - } - throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]"); - } - protected abstract void doToXContent(XContentBuilder builder, Params params) throws IOException; public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index fc2643d3a5c79..adba86bee7226 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -92,8 +92,10 @@ protected NamedXContentRegistry xContentRegistry() { } private static ScoreNormalizer randomScoreNormalizer() { - Float min = frequently() ? randomFloat() : null; - Float max = frequently() && min != null ? min + randomFloat() : null; - return new MinMaxScoreNormalizer(min, max); + if (randomBoolean()) { + return new MinMaxScoreNormalizer(); + } else { + return new IdentityScoreNormalizer(); + } } } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index 9ccd6b1888ca4..a265269c03870 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -125,69 +125,13 @@ setup: } ] } - } - }, - weight: 2 + } + }, + weight: 2 + } }, - { - # this one will return docs 3 and doc 4 with scores 15 and 12 respectively - retriever: { - standard: { - query: { - bool: { - should: [ - { - constant_score: { - filter: { - term: { - keyword: { - value: "three" - } - } - }, - boost: 5.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "four" - } - } - }, - boost: 4.0 - } - } - ] - } - } - }, - weight: 3 - } - ] - - - match: { hits.total.value: 4 } - - match: { hits.hits.0._id: "1" } - - match: { hits.hits.0._score: 20.0 } - - match: { hits.hits.1._id: "3" } - - match: { hits.hits.1._score: 15.0 } - - match: { hits.hits.2._id: "4" } - - match: { hits.hits.2._score: 12.0 } - - match: { hits.hits.3._id: "2" } - - match: { hits.hits.3._score: 10.0 } - ---- -"should normalize initial scores": - - do: - search: - index: test - body: - retriever: - linear: - retrievers: [ { + # this one will return docs 3 and doc 4 with scores 15 and 12 respectively retriever: { standard: { query: { @@ -198,23 +142,11 @@ setup: filter: { term: { keyword: { - value: "one" - } - } - }, - boost: 10.0 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "two" + value: "three" } } }, - boost: 9.0 + boost: 5.0 } }, { @@ -222,47 +154,34 @@ setup: filter: { term: { keyword: { - value: "three" + value: "four" } } }, - boost: 5.0 + boost: 4.0 } } ] } } - } - }, - weight: 10.0, - normalizer: "minmax" - }, - { - retriever: { - knn: { - field: "vector", - query_vector: [ 4 ], - k: 1, - num_candidates: 1 - } - }, - weight: 2.0 + }, + weight: 3 + } } ] - match: { hits.total.value: 4 } - match: { hits.hits.0._id: "1" } - - match: {hits.hits.0._score: 10.0} - - match: { hits.hits.1._id: "2" } - - match: {hits.hits.1._score: 8.0} + - match: { hits.hits.0._score: 20.0 } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.1._score: 15.0 } - match: { hits.hits.2._id: "4" } - - match: {hits.hits.2._score: 2.0} - - match: { hits.hits.2._score: 2.0 } - - match: { hits.hits.3._id: "3" } - - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } + - match: { hits.hits.2._score: 12.0 } + - match: { hits.hits.3._id: "2" } + - match: { hits.hits.3._score: 10.0 } --- -"should normalize initial scores with a custom minmax normalizer": +"should normalize initial scores": - do: search: index: test @@ -285,7 +204,7 @@ setup: } } }, - boost: 10.0 # normalized score for this would be -0.55 + boost: 10.0 } }, { @@ -297,7 +216,7 @@ setup: } } }, - boost: 9.0 # normalized score for this would be -0.56 + boost: 9.0 } }, { @@ -309,19 +228,7 @@ setup: } } }, - boost: 5.0 # normalized score for this would be -0.63 - } - }, - { - constant_score: { - filter: { - term: { - keyword: { - value: "four" - } - } - }, - boost: 1.0 # normalized score for this would be -0.7 + boost: 5.0 } } ] @@ -329,15 +236,10 @@ setup: } } }, - normalizer: { - minmax: { - min: 42, - max: 100 - } - } + weight: 10.0, + normalizer: "minmax" }, { - # this only provides a score of 10 for doc 4 retriever: { knn: { field: "vector", @@ -346,19 +248,20 @@ setup: num_candidates: 1 } }, - weight: 10.0 + weight: 2.0 } ] - match: { hits.total.value: 4 } - - match: { hits.hits.0._id: "4" } - - close_to: { hits.hits.0._score: { value: 50.2931, error: 0.001 } } - - match: { hits.hits.1._id: "1" } - - close_to: { hits.hits.1._score: { value: 40.4482, error: 0.001 } } - - match: { hits.hits.2._id: "2" } - - close_to: { hits.hits.2._score: { value: 40.4310, error: 0.001 } } + - match: { hits.hits.0._id: "1" } + - match: {hits.hits.0._score: 10.0} + - match: { hits.hits.1._id: "2" } + - match: {hits.hits.1._score: 8.0} + - match: { hits.hits.2._id: "4" } + - match: {hits.hits.2._score: 2.0} + - match: { hits.hits.2._score: 2.0 } - match: { hits.hits.3._id: "3" } - - close_to: { hits.hits.3._score: { value: 40.3620, error: 0.001 } } + - close_to: { hits.hits.3._score: { value: 0.0, error: 0.001 } } --- "should throw on unknown normalizer": @@ -1104,12 +1007,7 @@ setup: } }, weight: 1.0, - normalizer: { - minmax: { - min: 1577836800000, # 2020-01-01T00:00:00 - max: 1735689600000 # 2025-01-01T00:00:00 - } - } + normalizer: "minmax" } ] rank_window_size: 2 From 2bda448e454003cd9e413132f1064712bc9005bf Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 27 Jan 2025 20:57:39 +0200 Subject: [PATCH 52/57] adding assertion for negative weights --- docs/reference/rest-api/common-parms.asciidoc | 13 +------------ .../search-your-data/retrievers-examples.asciidoc | 13 ++----------- .../xpack/rank/linear/LinearRetrieverComponent.java | 3 +++ 3 files changed, 6 insertions(+), 23 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 59ba50942ab41..6d80db2e50538 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1385,24 +1385,13 @@ Specifies how we will normalize the retriever's scores, before applying the spec We can either provide a string reference to use with the default values or further configure any normalizer using its specific properties. Available values are: `minmax`, and `none`. Defaults to `none`. -** `none` : takes no argument +** `none` ** `minmax` : A `MinMaxScoreNormalizer` that normalizes scores based on the following formula + ``` score = (score - min) / (max - min) ``` -Available properties are: -*** `min`:: -(Optional, float) -+ -The minimum value of the original scores. Defaults to result set's true min value. - -*** `max`:: -(Optional, float) -+ -The maximum value of the original scores. Defaults to result set's true max value. - See also <> using a linear retriever on how to independently configure and apply normalizers to retrievers. diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index c195ca62bbf79..b69e56573b328 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -252,12 +252,7 @@ GET /retrievers_example/_search } }, "weight": 1.5, - "normalizer": { - "minmax": { - "min": 0.5, - "max": 1.0 - } - } + "normalizer": "minmax" } ], "rank_window_size": 10 @@ -359,11 +354,7 @@ GET /retrievers_example/_search } }, "weight": 2, - "normalizer": { - "minmax": { - "min": "1483228800000" - } - } + "normalizer": "minmax" }, { "retriever": { diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index e78e1c71eaa62..f1975307a7f7e 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -39,6 +39,9 @@ public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, this.retriever = retrieverBuilder; this.weight = weight == null ? DEFAULT_WEIGHT : weight; this.normalizer = normalizer == null ? DEFAULT_NORMALIZER : normalizer; + if (this.weight < 0) { + throw new IllegalArgumentException("[weight] must be non-negative"); + } } @Override From 8b07ea577dafb55a7a8afd9b259927b2bc98df9e Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Mon, 27 Jan 2025 23:37:23 +0200 Subject: [PATCH 53/57] updating tests after latest changes --- .../retrievers-examples.asciidoc | 4 +-- .../rank/linear/IdentityScoreNormalizer.java | 21 ------------ .../rank/linear/LinearRetrieverBuilder.java | 2 +- .../rank/linear/LinearRetrieverComponent.java | 2 +- .../rank/linear/MinMaxScoreNormalizer.java | 19 ++--------- .../xpack/rank/linear/ScoreNormalizer.java | 22 ++----------- .../LinearRetrieverBuilderParsingTests.java | 4 +-- .../test/linear/10_linear_retriever.yml | 32 +++++++++---------- 8 files changed, 27 insertions(+), 79 deletions(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index b69e56573b328..437eea44b3975 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -415,12 +415,12 @@ Which would return the following results: }, { "_index": "retrievers_example", - "_id": "1", + "_id": "4", "_score": -3 }, { "_index": "retrievers_example", - "_id": "4", + "_id": "1", "_score": -4 } ] diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java index 7b1f70b821b14..15af17a1db4ef 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/IdentityScoreNormalizer.java @@ -8,11 +8,6 @@ package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; - -import java.io.IOException; public class IdentityScoreNormalizer extends ScoreNormalizer { @@ -20,13 +15,6 @@ public class IdentityScoreNormalizer extends ScoreNormalizer { public static final String NAME = "none"; - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { - if (args.length != 0) { - throw new IllegalArgumentException("[IdentityScoreNormalizer] does not accept any arguments"); - } - return new IdentityScoreNormalizer(); - }); - @Override public String getName() { return NAME; @@ -36,13 +24,4 @@ public String getName() { public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { return docs; } - - public static IdentityScoreNormalizer fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); - } - - @Override - public void doToXContent(XContentBuilder builder, Params params) throws IOException { - // no-op - } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java index 15d790bc56242..66bbbf95bc9d6 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java @@ -197,7 +197,7 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.startObject(); builder.field(LinearRetrieverComponent.RETRIEVER_FIELD.getPreferredName(), entry.retriever()); builder.field(LinearRetrieverComponent.WEIGHT_FIELD.getPreferredName(), weights[index]); - builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index]); + builder.field(LinearRetrieverComponent.NORMALIZER_FIELD.getPreferredName(), normalizers[index].getName()); builder.endObject(); index++; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java index f1975307a7f7e..bb0d79d3fe488 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverComponent.java @@ -48,7 +48,7 @@ public LinearRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight, public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(RETRIEVER_FIELD.getPreferredName(), retriever); builder.field(WEIGHT_FIELD.getPreferredName(), weight); - builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer); + builder.field(NORMALIZER_FIELD.getPreferredName(), normalizer.getName()); return builder; } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java index 5726caeb58864..56b42b48a5d47 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/MinMaxScoreNormalizer.java @@ -8,21 +8,15 @@ package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentParser; public class MinMaxScoreNormalizer extends ScoreNormalizer { + public static final MinMaxScoreNormalizer INSTANCE = new MinMaxScoreNormalizer(); + public static final String NAME = "minmax"; private static final float EPSILON = 1e-6f; - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( - NAME, - args -> new MinMaxScoreNormalizer() - ); - public MinMaxScoreNormalizer() {} @Override @@ -68,13 +62,4 @@ public ScoreDoc[] normalizeScores(ScoreDoc[] docs) { } return scoreDocs; } - - public static MinMaxScoreNormalizer fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); - } - - @Override - public void doToXContent(XContentBuilder builder, Params params) { - // no-op - } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java index c6cfb7c1413b2..48334b9adf957 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/ScoreNormalizer.java @@ -8,39 +8,23 @@ package org.elasticsearch.xpack.rank.linear; import org.apache.lucene.search.ScoreDoc; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; /** * A no-op {@link ScoreNormalizer} that does not modify the scores. */ -public abstract class ScoreNormalizer implements ToXContent { +public abstract class ScoreNormalizer { public static ScoreNormalizer valueOf(String normalizer) { if (MinMaxScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { - return new MinMaxScoreNormalizer(); + return MinMaxScoreNormalizer.INSTANCE; } else if (IdentityScoreNormalizer.NAME.equalsIgnoreCase(normalizer)) { - return new IdentityScoreNormalizer(); + return IdentityScoreNormalizer.INSTANCE; } else { throw new IllegalArgumentException("Unknown normalizer [" + normalizer + "]"); } } - protected abstract void doToXContent(XContentBuilder builder, Params params) throws IOException; - - public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startObject(getName()); - doToXContent(builder, params); - builder.endObject(); - builder.endObject(); - - return builder; - } - public abstract String getName(); public abstract ScoreDoc[] normalizeScores(ScoreDoc[] docs); diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java index adba86bee7226..5cc66c6f50d3c 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderParsingTests.java @@ -93,9 +93,9 @@ protected NamedXContentRegistry xContentRegistry() { private static ScoreNormalizer randomScoreNormalizer() { if (randomBoolean()) { - return new MinMaxScoreNormalizer(); + return MinMaxScoreNormalizer.INSTANCE; } else { - return new IdentityScoreNormalizer(); + return IdentityScoreNormalizer.INSTANCE; } } } diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index a265269c03870..2704ee15a89ef 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -126,9 +126,9 @@ setup: ] } } - }, - weight: 2 - } + } + }, + weight: 2 }, { # this one will return docs 3 and doc 4 with scores 15 and 12 respectively @@ -164,9 +164,9 @@ setup: ] } } - }, - weight: 3 - } + } + }, + weight: 3 } ] @@ -291,9 +291,7 @@ setup: } }, weight: 1.0, - normalizer: { - aardvark: { } - } + normalizer: "aardvark" }, { retriever: { @@ -907,14 +905,14 @@ setup: { term: { keyword: { - value: "one" # this will give doc 1 a normalized score of 10 + value: "one" # this will give doc 1 a normalized score of 10 because min == max } } }, { term: { keyword: { - value: "two" # this will give doc 2 a normalized score of 10 + value: "two" # this will give doc 2 a normalized score of 10 because min == max } } } ] @@ -934,15 +932,17 @@ setup: normalizer: "minmax" }, { - # because we're sorting on timestamp and use a rank window size of 2, we will only get to see + # because we're sorting on timestamp and use a rank window size of 3, we will only get to see # docs 3 and 2. # their `scores` (which are the timestamps) are: # doc 3: 1672531200000 (2023-01-01T00:00:00) # doc 2: 1640995200000 (2022-01-01T00:00:00) + # doc 1: 1609459200000 (2021-01-01T00:00:00) # and their normalized scores based on the provided conf # will be: - # normalized(doc3) = 0.59989 - # normalized(doc2) = 0.40010 + # normalized(doc3) = 1. + # normalized(doc2) = 0.5 + # normalized(doc1) = 0 retriever: { standard: { query: { @@ -1010,12 +1010,12 @@ setup: normalizer: "minmax" } ] - rank_window_size: 2 + rank_window_size: 3 size: 2 - match: { hits.total.value: 3 } - length: {hits.hits: 2} - match: { hits.hits.0._id: "2" } - - close_to: { hits.hits.0._score: { value: 10.4001, error: 0.001 } } + - close_to: { hits.hits.0._score: { value: 10.5, error: 0.001 } } - match: { hits.hits.1._id: "1" } - match: { hits.hits.1._score: 10 } From 3ba0587dd62dfe9ae5f0253d3c4adf8bf1bb6da5 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 28 Jan 2025 00:01:32 +0200 Subject: [PATCH 54/57] Update common-parms.asciidoc --- docs/reference/rest-api/common-parms.asciidoc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 6d80db2e50538..37c5528812900 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -1376,14 +1376,13 @@ results, which will later be merged based on the specified `weight` and `normali * `weight`:: (Optional, float) + -The weight that each score of this retriever's top docs will be multiplied with. Defaults to 1.0. +The weight that each score of this retriever's top docs will be multiplied with. Must be greater or equal to 0. Defaults to 1.0. * `normalizer`:: -(Optional, String or Object) +(Optional, String) + Specifies how we will normalize the retriever's scores, before applying the specified `weight`. -We can either provide a string reference to use with the default values or further configure any normalizer -using its specific properties. Available values are: `minmax`, and `none`. Defaults to `none`. +Available values are: `minmax`, and `none`. Defaults to `none`. ** `none` ** `minmax` : From 3237ef573c2203011237b12a2897ea7ab69900b1 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 28 Jan 2025 00:10:21 +0200 Subject: [PATCH 55/57] Update retrievers-examples.asciidoc --- .../search/search-your-data/retrievers-examples.asciidoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/search/search-your-data/retrievers-examples.asciidoc b/docs/reference/search/search-your-data/retrievers-examples.asciidoc index 437eea44b3975..bc5f891a759b6 100644 --- a/docs/reference/search/search-your-data/retrievers-examples.asciidoc +++ b/docs/reference/search/search-your-data/retrievers-examples.asciidoc @@ -314,7 +314,7 @@ This returns the following response based on the normalized weighted score for e By normalizing scores and leveraging `function_score` queries, we can also implement more complex ranking strategies, such as sorting results based on their timestamps, assign the timestamp as a score, and then normalizing this score to -[0, 1] range where 1 is `today` and `0` is the oldest reference document in the index. +[0, 1]. Then, we can easily combine the above with a `knn` retriever as follows: [source,console] From 173f254ee2531a5e029d94e69bfc639db87b853e Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 28 Jan 2025 12:15:31 +0200 Subject: [PATCH 56/57] setting knn field to flat --- .../resources/rest-api-spec/test/linear/10_linear_retriever.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index 2704ee15a89ef..6ab6bce130a44 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -15,6 +15,8 @@ setup: dims: 1 index: true similarity: l2_norm + index_options: + type: flat keyword: type: keyword other_keyword: From 42c543ad1d551d5d6dcf7e57759d7c541f771327 Mon Sep 17 00:00:00 2001 From: Panagiotis Bailis Date: Tue, 28 Jan 2025 14:35:06 +0200 Subject: [PATCH 57/57] adding ids to parameter sections for retriever docs --- docs/reference/search/retriever.asciidoc | 14 +++++-- .../test/linear/10_linear_retriever.yml | 42 +++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index 547219a2bc924..200a8ae47be1c 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -48,6 +48,8 @@ A <> that applies contextual <> to pin o A standard retriever returns top documents from a traditional <>. +[discrete] +[[standard-retriever-parameters]] ===== Parameters: `query`:: @@ -198,6 +200,8 @@ Documents matching these conditions will have increased relevancy scores. A kNN retriever returns top documents from a <>. +[discrete] +[[knn-retriever-parameters]] ===== Parameters `field`:: @@ -270,10 +274,10 @@ This value must be fewer than or equal to `num_candidates`. [[linear-retriever]] ==== Linear Retriever -A retriever that normalizes and linearly combines the scores of other retrievers. If the final scores produced after the -weighted combination of all sub-retrievers are negative, a corrective factor is applied equal to the minimum score, -so all scores are positive. +A retriever that normalizes and linearly combines the scores of other retrievers. +[discrete] +[[linear-retriever-parameters]] ===== Parameters include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=linear-retriever-components] @@ -288,6 +292,8 @@ include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=compound-retriever-filt An <> retriever returns top documents based on the RRF formula, equally weighting two or more child retrievers. Reciprocal rank fusion (RRF) is a method for combining multiple result sets with different relevance indicators into a single result set. +[discrete] +[[rrf-retriever-parameters]] ===== Parameters include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=rrf-retrievers] @@ -540,6 +546,8 @@ You have the following options: ** Then set up an <> with the `rerank` task type. ** Refer to the <> on this page for a step-by-step guide. +[discrete] +[[text-similarity-reranker-retriever-parameters]] ===== Parameters `retriever`:: diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml index 6ab6bce130a44..70db6c1543365 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/linear/10_linear_retriever.yml @@ -308,6 +308,48 @@ setup: } ] +--- +"should throw on negative weights": + - do: + catch: /\[weight\] must be non-negative/ + search: + index: test + body: + retriever: + linear: + retrievers: [ + { + retriever: { + standard: { + query: { + constant_score: { + filter: { + term: { + keyword: { + value: "one" + } + } + }, + boost: 10.0 + } + } + } + }, + weight: 1.0 + }, + { + retriever: { + knn: { + field: "vector", + query_vector: [ 4 ], + k: 1, + num_candidates: 1 + } + }, + weight: -10 + } + ] + --- "pagination within a consistent rank_window_size": - do: