diff --git a/docs/changelog/130658.yaml b/docs/changelog/130658.yaml new file mode 100644 index 0000000000000..c5075b70db5b5 --- /dev/null +++ b/docs/changelog/130658.yaml @@ -0,0 +1,5 @@ +pr: 130658 +summary: Add support for weighted RRF in retrievers +area: Relevance +type: enhancement +issues: [] diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java index f76c22fe1344e..326a2f276fa6a 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java @@ -36,7 +36,8 @@ public Set getTestFeatures() { LINEAR_RETRIEVER_L2_NORM, LINEAR_RETRIEVER_MINSCORE_FIX, LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT, - RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT + RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT, + RRFRetrieverBuilder.WEIGHTED_SUPPORT ); } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index c05fd3bd0a11f..702bb0df0f9eb 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; +import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource; import org.elasticsearch.search.retriever.RetrieverBuilder; import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.retriever.StandardRetrieverBuilder; @@ -37,7 +38,7 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT; /** * An rrf retriever is used to represent an rrf rank element, but @@ -48,6 +49,7 @@ */ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder { public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support"); + public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support"); public static final String NAME = "rrf"; @@ -57,37 +59,38 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder PARSER = new ConstructingObjectParser<>( NAME, false, args -> { - List childRetrievers = (List) args[0]; + List retrieverComponents = args[0] == null ? List.of() : (List) args[0]; List fields = (List) args[1]; String query = (String) args[2]; int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3]; int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4]; - List innerRetrievers = childRetrievers != null - ? childRetrievers.stream().map(RetrieverSource::from).toList() - : List.of(); - return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant); + int n = retrieverComponents.size(); + List innerRetrievers = new ArrayList<>(n); + float[] weights = new float[n]; + for (int i = 0; i < n; i++) { + RRFRetrieverComponent component = retrieverComponents.get(i); + innerRetrievers.add(RetrieverSource.from(component.retriever())); + weights[i] = component.weight(); + } + return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights); } ); static { - PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> { - p.nextToken(); - String name = p.currentName(); - RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c); - c.trackRetrieverUsage(retrieverBuilder.getName()); - p.nextToken(); - return retrieverBuilder; - }, RETRIEVERS_FIELD); - PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD); - PARSER.declareString(optionalConstructorArg(), QUERY_FIELD); - PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); - PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD); + PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD); + PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD); RetrieverBuilder.declareBaseParserFields(PARSER); } @@ -103,7 +106,14 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP private final int rankConstant; public RRFRetrieverBuilder(List childRetrievers, int rankWindowSize, int rankConstant) { - this(childRetrievers, null, null, rankWindowSize, rankConstant); + this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers)); + } + + private static float[] createDefaultWeights(List retrievers) { + int size = retrievers == null ? 0 : retrievers.size(); + float[] defaultWeights = new float[size]; + Arrays.fill(defaultWeights, DEFAULT_WEIGHT); + return defaultWeights; } public RRFRetrieverBuilder( @@ -111,19 +121,31 @@ public RRFRetrieverBuilder( List fields, String query, int rankWindowSize, - int rankConstant + int rankConstant, + float[] weights ) { // Use a mutable list for childRetrievers so that we can use addChild super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize); this.fields = fields == null ? null : List.copyOf(fields); this.query = query; this.rankConstant = rankConstant; + Objects.requireNonNull(weights, "weights must not be null"); + if (weights.length != innerRetrievers.size()) { + throw new IllegalArgumentException( + "weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]" + ); + } + this.weights = weights; } public int rankConstant() { return rankConstant; } + public float[] weights() { + return weights; + } + @Override public String getName() { return NAME; @@ -137,6 +159,7 @@ public ActionRequestValidationException validate( boolean allowPartialSearchResults ) { validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults); + return MultiFieldsInnerRetrieverUtils.validateParams( innerRetrievers, fields, @@ -151,7 +174,14 @@ public ActionRequestValidationException validate( @Override protected RRFRetrieverBuilder clone(List newRetrievers, List newPreFilterQueryBuilders) { - RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant); + RRFRetrieverBuilder clone = new RRFRetrieverBuilder( + newRetrievers, + this.fields, + this.query, + this.rankWindowSize, + this.rankConstant, + this.weights + ); clone.preFilterQueryBuilders = newPreFilterQueryBuilders; clone.retrieverName = retrieverName; return clone; @@ -183,7 +213,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List rankResults // calculate the current rrf score for this document // later used to sort and covert to a rank - value.score += 1.0f / (rankConstant + frank); + value.score += this.weights[findex] * (1.0f / (rankConstant + frank)); if (explain && value.positions != null && value.scores != null) { // record the position for each query @@ -238,10 +268,14 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { query, localIndicesMetadata.values(), r -> { - List retrievers = r.stream() - .map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource) - .toList(); - return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant); + List retrievers = new ArrayList<>(r.size()); + float[] weights = new float[r.size()]; + for (int i = 0; i < r.size(); i++) { + var retriever = r.get(i); + retrievers.add(retriever.retrieverSource()); + weights[i] = retriever.weight(); + } + return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights); }, w -> { if (w != 1.0f) { @@ -255,7 +289,8 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { if (fieldsInnerRetrievers.isEmpty() == false) { // TODO: This is a incomplete solution as it does not address other incomplete copy issues // (such as dropping the retriever name and min score) - rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant); + float[] weights = createDefaultWeights(fieldsInnerRetrievers); + rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, null, null, rankWindowSize, rankConstant, weights); rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); } else { // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices @@ -266,29 +301,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) { return rewritten; } - // ---- FOR TESTING XCONTENT PARSING ---- - - @Override - public boolean doEquals(Object o) { - RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; - return super.doEquals(o) - && Objects.equals(fields, that.fields) - && Objects.equals(query, that.query) - && rankConstant == that.rankConstant; - } - - @Override - public int doHashCode() { - return Objects.hash(super.doHashCode(), fields, query, rankConstant); - } - @Override public void doToXContent(XContentBuilder builder, Params params) throws IOException { if (innerRetrievers.isEmpty() == false) { builder.startArray(RETRIEVERS_FIELD.getPreferredName()); - - for (var entry : innerRetrievers) { - entry.retriever().toXContent(builder, params); + for (int i = 0; i < innerRetrievers.size(); i++) { + RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]); + component.toXContent(builder, params); } builder.endArray(); } @@ -307,4 +326,20 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize); builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant); } + + // ---- FOR TESTING XCONTENT PARSING ---- + @Override + public boolean doEquals(Object o) { + RRFRetrieverBuilder that = (RRFRetrieverBuilder) o; + return super.doEquals(o) + && Objects.equals(fields, that.fields) + && Objects.equals(query, that.query) + && rankConstant == that.rankConstant + && Arrays.equals(weights, that.weights); + } + + @Override + public int doHashCode() { + return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights)); + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverComponent.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverComponent.java new file mode 100644 index 0000000000000..4946407fb19fb --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverComponent.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.rank.rrf; + +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.search.retriever.RetrieverBuilder; +import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class RRFRetrieverComponent implements ToXContentObject { + + public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); + public static final ParseField WEIGHT_FIELD = new ParseField("weight"); + static final float DEFAULT_WEIGHT = 1f; + + final RetrieverBuilder retriever; + final float weight; + + public RRFRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight) { + this.retriever = Objects.requireNonNull(retrieverBuilder, "retrieverBuilder must not be null"); + this.weight = weight == null ? DEFAULT_WEIGHT : weight; + if (this.weight < 0) { + throw new IllegalArgumentException("[weight] must be non-negative, found [" + this.weight + "]"); + } + } + + public RetrieverBuilder retriever() { + return retriever; + } + + public float weight() { + return weight; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException { + builder.startObject(); + builder.field(RETRIEVER_FIELD.getPreferredName(), retriever); + builder.field(WEIGHT_FIELD.getPreferredName(), weight); + builder.endObject(); + return builder; + } + + public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new ParsingException(parser.getTokenLocation(), "expected object but found [{}]", parser.currentToken()); + } + + // Peek at the first field to determine the format + XContentParser.Token token = parser.nextToken(); + if (token == XContentParser.Token.END_OBJECT) { + throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever"); + } + if (token != XContentParser.Token.FIELD_NAME) { + throw new ParsingException(parser.getTokenLocation(), "expected field name but found [{}]", token); + } + + String firstFieldName = parser.currentName(); + + // Check if this is a structured component (starts with "retriever" or "weight") + if (RETRIEVER_FIELD.match(firstFieldName, parser.getDeprecationHandler()) + || WEIGHT_FIELD.match(firstFieldName, parser.getDeprecationHandler())) { + // This is a structured component - parse manually + RetrieverBuilder retriever = null; + Float weight = null; + + do { + String fieldName = parser.currentName(); + if (RETRIEVER_FIELD.match(fieldName, parser.getDeprecationHandler())) { + if (retriever != null) { + throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified"); + } + parser.nextToken(); + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new ParsingException(parser.getTokenLocation(), "retriever must be an object"); + } + parser.nextToken(); + String retrieverType = parser.currentName(); + retriever = parser.namedObject(RetrieverBuilder.class, retrieverType, context); + context.trackRetrieverUsage(retriever.getName()); + parser.nextToken(); + } else if (WEIGHT_FIELD.match(fieldName, parser.getDeprecationHandler())) { + if (weight != null) { + throw new ParsingException(parser.getTokenLocation(), "[weight] field can only be specified once"); + } + parser.nextToken(); + weight = parser.floatValue(); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "unknown field [{}], expected [{}] or [{}]", + fieldName, + RETRIEVER_FIELD.getPreferredName(), + WEIGHT_FIELD.getPreferredName() + ); + } + } while (parser.nextToken() == XContentParser.Token.FIELD_NAME); + + if (retriever == null) { + throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever"); + } + + return new RRFRetrieverComponent(retriever, weight); + } else { + RetrieverBuilder retriever = parser.namedObject(RetrieverBuilder.class, firstFieldName, context); + context.trackRetrieverUsage(retriever.getName()); + if (parser.nextToken() != XContentParser.Token.END_OBJECT) { + throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", parser.currentName()); + } + return new RRFRetrieverComponent(retriever, DEFAULT_WEIGHT); + } + } +} diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java index add6f271b06ba..f518377c5c636 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -56,12 +57,15 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() { int retrieverCount = randomIntBetween(2, 50); List innerRetrievers = new ArrayList<>(retrieverCount); + float[] weights = new float[retrieverCount]; + int i = 0; while (retrieverCount > 0) { innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder())); + weights[i++] = randomFloat(); --retrieverCount; } - return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant); + return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights); } @Override @@ -89,7 +93,7 @@ protected NamedXContentRegistry xContentRegistry() { new NamedXContentRegistry.Entry( RetrieverBuilder.class, TestRetrieverBuilder.TEST_SPEC.getName(), - (p, c) -> TestRetrieverBuilder.TEST_SPEC.getParser().fromXContent(p, (RetrieverParserContext) c), + (p, c) -> TestRetrieverBuilder.fromXContent(p, (RetrieverParserContext) c), TestRetrieverBuilder.TEST_SPEC.getName().getForRestApiVersion() ) ); @@ -103,6 +107,28 @@ protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(entries); } + private void checkRRFRetrieverParsing(String restContent) throws IOException { + SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); + try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever(); + assertThat(parsed.minScore(), equalTo(20f)); + assertThat(parsed.retrieverName(), equalTo("foo_rrf")); + try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { + SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( + parseSerialized, + true, + searchUsageHolder, + nf -> true + ); + assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever(); + assertThat(parsed, equalTo(deserialized)); + } + } + } + public void testRRFRetrieverParsing() throws IOException { String restContent = """ { @@ -130,24 +156,226 @@ public void testRRFRetrieverParsing() throws IOException { } } """; + checkRRFRetrieverParsing(restContent); + } + + public void testRRFRetrieverParsingWithWeights() throws IOException { + String restContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "retriever": { + "test": { + "value": "first" + } + }, + "weight": 2.0 + }, + { + "retriever": { + "test": { + "value": "second" + } + }, + "weight": 0.5 + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + checkRRFRetrieverParsing(restContent); + } + + public void testRRFRetrieverParsingWithMixedWeights() throws IOException { + String restContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "test": { + "value": "no_weight" + } + }, + { + "retriever": { + "test": { + "value": "with_weight" + } + }, + "weight": 1.5 + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + checkRRFRetrieverParsing(restContent); + } + + public void testRRFRetrieverParsingWithDefaultWeights() throws IOException { + String restContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "test": { + "value": "first" + } + }, + { + "test": { + "value": "second" + } + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + checkRRFRetrieverParsing(restContent); + } + + public void testRRFRetrieverComponentErrorCases() throws IOException { + // Test case 1: Multiple retrievers in same component + String multipleRetrieversContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "retriever": { "test": { "value": "first" } }, + "standard": { "query": { "match_all": {} } } + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + + expectParsingException(multipleRetrieversContent, "unknown field [standard], expected [retriever] or [weight]"); + + // Test case 2: Weight without retriever + String weightOnlyContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "weight": 2.0 + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + + expectParsingException(weightOnlyContent, "retriever component must contain a retriever"); + + // Test case 3: Empty retriever component + String emptyComponentContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + {} + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + + expectParsingException(emptyComponentContent, "retriever component must contain a retriever"); + + // Test case 4: Negative weight + String negativeWeightContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "retriever": { "test": { "value": "test" } }, + "weight": -1.0 + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + + expectParsingException(negativeWeightContent, "[weight] must be non-negative"); + + // Test case 5: Retriever as non-object + String retrieverAsStringContent = """ + { + "retriever": { + "rrf": { + "retrievers": [ + { + "retriever": "not_an_object" + } + ], + "rank_window_size": 100, + "rank_constant": 10, + "min_score": 20.0, + "_name": "foo_rrf" + } + } + } + """; + + expectParsingException(retrieverAsStringContent, "retriever must be an object"); + } + + private void expectParsingException(String restContent, String expectedMessageFragment) throws IOException { SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder(); try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) { - SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); - assertThat(source.retriever(), instanceOf(RRFRetrieverBuilder.class)); - RRFRetrieverBuilder parsed = (RRFRetrieverBuilder) source.retriever(); - assertThat(parsed.minScore(), equalTo(20f)); - assertThat(parsed.retrieverName(), equalTo("foo_rrf")); - try (XContentParser parseSerialized = createParser(JsonXContent.jsonXContent, Strings.toString(source))) { - SearchSourceBuilder deserializedSource = new SearchSourceBuilder().parseXContent( - parseSerialized, - true, - searchUsageHolder, - nf -> true - ); - assertThat(deserializedSource.retriever(), instanceOf(RRFRetrieverBuilder.class)); - RRFRetrieverBuilder deserialized = (RRFRetrieverBuilder) source.retriever(); - assertThat(parsed, equalTo(deserialized)); + Exception exception = expectThrows(Exception.class, () -> { + new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true); + }); + + String message = exception.getMessage(); + if (exception.getCause() != null) { + message = exception.getCause().getMessage(); } + + assertThat( + "Expected error message to contain: " + expectedMessageFragment + ", but got: " + message, + message, + containsString(expectedMessageFragment) + ); } } } diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java index 5e8d46cb5b27a..7885ac9df2aa8 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java @@ -39,8 +39,10 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiConsumer; import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE; +import static org.hamcrest.Matchers.instanceOf; /** Tests for the rrf retriever. */ public class RRFRetrieverBuilderTests extends ESTestCase { @@ -84,6 +86,61 @@ public void testRetrieverExtractionErrors() throws IOException { } } + public void testRRFRetrieverParsingSyntax() throws IOException { + BiConsumer testCase = (json, expectedWeights) -> { + try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) { + SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true); + assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class)); + RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever(); + assertArrayEquals(expectedWeights, rrf.weights(), 0.001f); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + + String legacyJson = """ + { + "retriever": { + "rrf_nl": { + "retrievers": [ + { "standard": { "query": { "match_all": {} } } }, + { "standard": { "query": { "match_all": {} } } } + ] + } + } + } + """; + testCase.accept(legacyJson, new float[] { 1.0f, 1.0f }); + + String weightedJson = """ + { + "retriever": { + "rrf_nl": { + "retrievers": [ + { "retriever": { "standard": { "query": { "match_all": {} } } }, "weight": 2.5 }, + { "retriever": { "standard": { "query": { "match_all": {} } } }, "weight": 0.5 } + ] + } + } + } + """; + testCase.accept(weightedJson, new float[] { 2.5f, 0.5f }); + + String mixedJson = """ + { + "retriever": { + "rrf_nl": { + "retrievers": [ + { "standard": { "query": { "match_all": {} } } }, + { "retriever": { "standard": { "query": { "match_all": {} } } }, "weight": 0.6 } + ] + } + } + } + """; + testCase.accept(mixedJson, new float[] { 1.0f, 0.6f }); + } + public void testMultiFieldsParamsRewrite() { final String indexName = "test-index"; final List testInferenceFields = List.of("semantic_field_1", "semantic_field_2"); @@ -103,7 +160,8 @@ public void testMultiFieldsParamsRewrite() { List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"), "foo", DEFAULT_RANK_WINDOW_SIZE, - RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] ); assertMultiFieldsParamsRewrite( rrfRetrieverBuilder, @@ -119,7 +177,8 @@ public void testMultiFieldsParamsRewrite() { List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"), "foo2", DEFAULT_RANK_WINDOW_SIZE * 2, - RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT / 2 + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT / 2, + new float[0] ); assertMultiFieldsParamsRewrite( rrfRetrieverBuilder, @@ -135,7 +194,8 @@ public void testMultiFieldsParamsRewrite() { List.of("field_*", "*_field_1"), "bar", DEFAULT_RANK_WINDOW_SIZE, - RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] ); assertMultiFieldsParamsRewrite( rrfRetrieverBuilder, @@ -151,7 +211,8 @@ public void testMultiFieldsParamsRewrite() { List.of("*"), "baz", DEFAULT_RANK_WINDOW_SIZE, - RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] ); assertMultiFieldsParamsRewrite( rrfRetrieverBuilder, @@ -182,7 +243,8 @@ public void testSearchRemoteIndex() { null, "foo", DEFAULT_RANK_WINDOW_SIZE, - RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT + RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT, + new float[0] ); IllegalArgumentException iae = expectThrows( diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/320_rrf_weighted_retriever.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/320_rrf_weighted_retriever.yml new file mode 100644 index 0000000000000..f923345c78a96 --- /dev/null +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/320_rrf_weighted_retriever.yml @@ -0,0 +1,139 @@ +setup: + - requires: + cluster_features: [ "rrf_retriever.weighted_support" ] + reason: "RRF retriever Weighted support" + test_runner_features: [ "contains", "close_to" ] + - do: + indices.create: + index: restaurants + body: + mappings: + properties: + name: { type: keyword } + description: { type: text } + city: { type: keyword } + region: { type: keyword } + vector: { type: dense_vector, dims: 3 } + - do: + index: + index: restaurants + id: "1" + body: { name: "Pizza Palace", description: "Best pizza in town", city: "Vienna", region: "Austria", vector: [10,22,77] } + - do: + index: + index: restaurants + id: "2" + body: { name: "Burger House", description: "Juicy burgers", city: "Graz", region: "Austria", vector: [15,25,70] } + - do: + index: + index: restaurants + id: "3" + body: { name: "Sushi World", description: "Fresh sushi", city: "Linz", region: "Austria", vector: [11,24,75] } + - do: + indices.refresh: { index: restaurants } + +--- +"Weighted RRF retriever returns correct results": + - do: + search: + index: restaurants + body: + retriever: + rrf: + retrievers: + - retriever: + standard: + query: + multi_match: + query: "Austria" + fields: ["city", "region"] + weight: 0.3 + - retriever: + standard: + query: + match: + description: "pizza" + weight: 0.7 + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "1" } + +--- +"Weighted RRF retriever allows optional weight field": + - do: + search: + index: restaurants + body: + retriever: + rrf: + retrievers: + - standard: + query: + multi_match: + query: "Austria" + fields: ["city", "region"] + - retriever: + standard: + query: + match: + description: "pizza" + weight: 0.7 + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "1" } + +--- +"Weighted RRF retriever changes result order": + - do: + search: + index: restaurants + body: + retriever: + rrf: + retrievers: + - retriever: + standard: + query: + match: + description: "pizza" + weight: 0.1 + - retriever: + standard: + query: + match: + description: "burgers" + weight: 0.9 + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "2" } + - match: { hits.hits.1._id: "1" } + # Document 2: matches "burgers" with weight 0.9 + # RRF score = 1/(60+1) * 0.9 = 0.01475 + - close_to: {hits.hits.0._score: {value: 0.01475, error: 0.0001}} + # Document 1: matches "pizza" with weight 0.1 + # RRF score = 1/(60+1) * 0.1 = 0.00164 + - close_to: {hits.hits.1._score: {value: 0.00164, error: 0.0001}} + +--- +"Weighted RRF retriever errors on negative weight": + - do: + catch: bad_request + search: + index: restaurants + body: + retriever: + rrf: + retrievers: + - retriever: + standard: + query: + multi_match: + query: "Austria" + fields: ["city", "region"] + weight: -0.5 + - retriever: + standard: + query: + match: + description: "pizza" + weight: 0.7 + - match: { error.type: "x_content_parse_exception" } + - contains: { error.caused_by.reason: "[weight] must be non-negative, found [-0.5]" } +