Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8be20f0
RRFRetrieverComponent added:
mridula-s109 Jul 4, 2025
a8f6487
Modified parser, toXcontent and included component in the RetrieverBu…
mridula-s109 Jul 4, 2025
e07c38d
[CI] Auto commit changes from spotless
Jul 4, 2025
33d3da4
Resolved merge conflicts
mridula-s109 Jul 15, 2025
5fb5568
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 15, 2025
3ba149c
Fixed compile issues in tests
mridula-s109 Jul 15, 2025
d5749f6
[CI] Auto commit changes from spotless
Jul 15, 2025
7614936
trying to resolve parse errros
mridula-s109 Jul 16, 2025
a5d9e34
wip
ioanatia Jul 17, 2025
0640099
Modified builder
mridula-s109 Jul 17, 2025
cec23c2
[CI] Auto commit changes from spotless
Jul 17, 2025
6da9e15
Removed unnecessary code
mridula-s109 Jul 18, 2025
51b350e
Fixed import
mridula-s109 Jul 18, 2025
4050a3a
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 18, 2025
ea664eb
Enhanced tests
mridula-s109 Jul 18, 2025
98e72be
Fixed the failing tests
mridula-s109 Jul 21, 2025
7de8c7a
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 21, 2025
c778dd2
Yaml tests were added
mridula-s109 Jul 22, 2025
c7b331d
Added cluster features to it
mridula-s109 Jul 22, 2025
f543cbe
Fixed spotless
mridula-s109 Jul 22, 2025
75ab8d0
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 22, 2025
f5086c1
Update docs/changelog/130658.yaml
mridula-s109 Jul 22, 2025
fafb50f
Fixed the relaxed constraints
mridula-s109 Jul 23, 2025
e535864
Resolving issues
mridula-s109 Jul 23, 2025
78f8641
Resolved PR comments
mridula-s109 Jul 23, 2025
02647b1
removed simplified rrf
mridula-s109 Jul 23, 2025
2010f3a
changed the test file back to its original state
mridula-s109 Jul 24, 2025
7433023
Resolved comments to have ahelper method and the test case to use it
mridula-s109 Jul 24, 2025
a2bf4de
made parsing robust
mridula-s109 Jul 24, 2025
eebf577
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 24, 2025
0388abd
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 24, 2025
74ed8db
IT test reverted
mridula-s109 Jul 24, 2025
6d7e8ff
Replaced the declareString array parser
mridula-s109 Jul 25, 2025
f1e14ce
Enforced weights as nonnull
mridula-s109 Jul 25, 2025
fd30387
Fixed the weights null
mridula-s109 Jul 25, 2025
3a82a28
Empty weight shouldnt be serialised
mridula-s109 Jul 25, 2025
77c14d3
[CI] Auto commit changes from spotless
Jul 25, 2025
45ca068
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 25, 2025
2e121b0
removed the hard coding
mridula-s109 Jul 28, 2025
d66f5a6
Cleanup and optimised the code flow
mridula-s109 Jul 28, 2025
532e7df
Fixed the comments
mridula-s109 Jul 28, 2025
184330c
[CI] Auto commit changes from spotless
Jul 28, 2025
c43a075
optimised test
mridula-s109 Jul 28, 2025
e5f8079
Added additional test
mridula-s109 Jul 28, 2025
65cc528
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 28, 2025
4ac4940
addressed the commentS
mridula-s109 Jul 29, 2025
e6f22bc
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 29, 2025
3fc0d35
Update docs/changelog/130658.yaml
mridula-s109 Jul 31, 2025
5c364a0
Explicit check for retriever object
mridula-s109 Jul 31, 2025
978e182
Resolved PR comments
mridula-s109 Jul 31, 2025
032e946
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 31, 2025
1d807ee
Fixed the error message
mridula-s109 Jul 31, 2025
7f4c7cd
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 31, 2025
99f4ad2
Merge branch 'main' into SEARCH-1026-implement-support-for-weighted-rrf
mridula-s109 Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/changelog/130658.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pr: 130658
summary: Implement support for weighted rrf
summary: Add support for weighted RRF in retrievers
area: Relevance
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
public static final ParseField FIELDS_FIELD = new ParseField("fields");
public static final ParseField QUERY_FIELD = new ParseField("query");
public static final ParseField WEIGHTS_FIELD = new ParseField("weights");

public static final int DEFAULT_RANK_CONSTANT = 60;

Expand Down Expand Up @@ -92,7 +91,6 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
PARSER.declareFloatArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS_FIELD);
RetrieverBuilder.declareBaseParserFields(PARSER);
}

Expand Down Expand Up @@ -265,17 +263,18 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
);
}

List<RetrieverBuilder> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
fields,
query,
localIndicesMetadata.values(),
r -> {
List<RetrieverSource> retrievers = new ArrayList<>(r.size());
for (var retriever : r) {
float[] weights = new float[r.size()];
for (int i = 0; i < r.size(); i++) {
var retriever = r.get(i);
retrievers.add(retriever.retrieverSource());
weights[i] = retriever.weight();
}
float[] weights = new float[retrievers.size()];
Arrays.fill(weights, 1.0f);
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
},
w -> {
Expand All @@ -285,20 +284,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
);
}
}
);
).stream().map(RetrieverSource::from).toList();

if (fieldsInnerRetrievers.isEmpty() == false) {
// TODO: This is a incomplete solution as it does not address other incomplete copy issues
// (such as dropping the retriever name and min score)
int size = fieldsInnerRetrievers.size();
List<RetrieverSource> sources = new ArrayList<>(size);
float[] weights = new float[size];
Arrays.fill(weights, RRFRetrieverComponent.DEFAULT_WEIGHT);
for (int i = 0; i < size; i++) {
sources.add(RetrieverSource.from(fieldsInnerRetrievers.get(i)));
weights[i] = RRFRetrieverComponent.DEFAULT_WEIGHT;
}
rewritten = new RRFRetrieverBuilder(sources, null, null, rankWindowSize, rankConstant, weights);
float[] weights = createDefaultWeights(fieldsInnerRetrievers);
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, null, null, rankWindowSize, rankConstant, weights);
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
} else {
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
Expand Down Expand Up @@ -333,13 +325,6 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept

builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
if (weights.length > 0) {
builder.startArray(WEIGHTS_FIELD.getPreferredName());
for (float weight : weights) {
builder.value(weight);
}
builder.endArray();
}
}

// ---- FOR TESTING XCONTENT PARSING ----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -19,9 +18,6 @@
import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class RRFRetrieverComponent implements ToXContentObject {

public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
Expand Down Expand Up @@ -56,26 +52,6 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para
return builder;
}

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<RRFRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
"rrf_component",
false,
(args, context) -> {
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0];
Float weight = (Float) args[1];
return new RRFRetrieverComponent(retrieverBuilder, weight);
}
);

static {
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
c.trackRetrieverUsage(innerRetriever.getName());
return innerRetriever;
}, RETRIEVER_FIELD);
PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD);
}

public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "expected object but found [{}]", parser.currentToken());
Expand Down Expand Up @@ -106,6 +82,9 @@ public static RRFRetrieverComponent fromXContent(XContentParser parser, Retrieve
throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified");
}
parser.nextToken();
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "retriever must be an object");
}
parser.nextToken();
String retrieverType = parser.currentName();
retriever = parser.namedObject(RetrieverBuilder.class, retrieverType, context);
Expand All @@ -118,9 +97,6 @@ public static RRFRetrieverComponent fromXContent(XContentParser parser, Retrieve
parser.nextToken();
weight = parser.floatValue();
} else {
if (retriever != null) {
throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified");
}
throw new ParsingException(
parser.getTokenLocation(),
"unknown field [{}], expected [{}] or [{}]",
Expand All @@ -139,7 +115,8 @@ public static RRFRetrieverComponent fromXContent(XContentParser parser, Retrieve
} else {
RetrieverBuilder retriever = parser.namedObject(RetrieverBuilder.class, firstFieldName, context);
context.trackRetrieverUsage(retriever.getName());
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", parser.currentName());
}
return new RRFRetrieverComponent(retriever, DEFAULT_WEIGHT);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,13 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {

int retrieverCount = randomIntBetween(2, 50);
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
float[] weights = new float[retrieverCount];
int i = 0;
while (retrieverCount > 0) {
innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
weights[i++] = randomFloat();
--retrieverCount;
}
float[] weights = new float[innerRetrievers.size()];
for (int i = 0; i < innerRetrievers.size(); i++) {
weights[i] = randomFloat();
}

return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
}
Expand Down Expand Up @@ -274,7 +273,7 @@ public void testRRFRetrieverComponentErrorCases() throws IOException {
}
""";

expectParsingException(multipleRetrieversContent, "only one retriever can be specified");
expectParsingException(multipleRetrieversContent, "unknown field [standard], expected [retriever] or [weight]");

// Test case 2: Weight without retriever
String weightOnlyContent = """
Expand Down Expand Up @@ -336,7 +335,28 @@ public void testRRFRetrieverComponentErrorCases() throws IOException {
}
""";

expectParsingException(negativeWeightContent, "weight] must be non-negative");
expectParsingException(negativeWeightContent, "[weight] must be non-negative");

// Test case 5: Retriever as non-object
String retrieverAsStringContent = """
{
"retriever": {
"rrf": {
"retrievers": [
{
"retriever": "not_an_object"
}
],
"rank_window_size": 100,
"rank_constant": 10,
"min_score": 20.0,
"_name": "foo_rrf"
}
}
}
""";

expectParsingException(retrieverAsStringContent, "retriever must be an object");
}

private void expectParsingException(String restContent, String expectedMessageFragment) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;

import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -86,6 +87,17 @@ public void testRetrieverExtractionErrors() throws IOException {
}

public void testRRFRetrieverParsingSyntax() throws IOException {
BiConsumer<String, float[]> testCase = (json, expectedWeights) -> {
try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true);
assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class));
RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever();
assertArrayEquals(expectedWeights, rrf.weights(), 0.001f);
} catch (IOException e) {
throw new RuntimeException(e);
}
};

String legacyJson = """
{
"retriever": {
Expand All @@ -98,12 +110,7 @@ public void testRRFRetrieverParsingSyntax() throws IOException {
}
}
""";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, legacyJson)) {
SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true);
assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class));
RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever();
assertArrayEquals(new float[] { 1.0f, 1.0f }, rrf.weights(), 0.001f);
}
testCase.accept(legacyJson, new float[] { 1.0f, 1.0f });

String weightedJson = """
{
Expand All @@ -117,12 +124,7 @@ public void testRRFRetrieverParsingSyntax() throws IOException {
}
}
""";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, weightedJson)) {
SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true);
assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class));
RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever();
assertArrayEquals(new float[] { 2.5f, 0.5f }, rrf.weights(), 0.001f);
}
testCase.accept(weightedJson, new float[] { 2.5f, 0.5f });

String mixedJson = """
{
Expand All @@ -136,12 +138,7 @@ public void testRRFRetrieverParsingSyntax() throws IOException {
}
}
""";
try (XContentParser parser = createParser(JsonXContent.jsonXContent, mixedJson)) {
SearchSourceBuilder ssb = new SearchSourceBuilder().parseXContent(parser, true, nf -> true);
assertThat(ssb.retriever(), instanceOf(RRFRetrieverBuilder.class));
RRFRetrieverBuilder rrf = (RRFRetrieverBuilder) ssb.retriever();
assertArrayEquals(new float[] { 1.0f, 0.6f }, rrf.weights(), 0.001f);
}
testCase.accept(mixedJson, new float[] { 1.0f, 0.6f });
}

public void testMultiFieldsParamsRewrite() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ setup:
- requires:
cluster_features: [ "rrf_retriever.weighted_support" ]
reason: "RRF retriever Weighted support"
test_runner_features: [ "contains" ]
test_runner_features: [ "contains", "close_to" ]
- do:
indices.create:
index: restaurants
Expand Down Expand Up @@ -58,7 +58,7 @@ setup:
- match: { hits.hits.0._id: "1" }

---
"Weighted RRF retriever defaults to weight 1":
"Weighted RRF retriever allows optional weight field":
- do:
search:
index: restaurants
Expand All @@ -80,6 +80,37 @@ setup:
- match: { hits.total.value: 3 }
- match: { hits.hits.0._id: "1" }

---
"Weighted RRF retriever changes result order":
- do:
search:
index: restaurants
body:
retriever:
rrf:
retrievers:
- retriever:
standard:
query:
match:
description: "pizza"
weight: 0.1
- retriever:
standard:
query:
match:
description: "burgers"
weight: 0.9
- match: { hits.total.value: 2 }
- match: { hits.hits.0._id: "2" }
- match: { hits.hits.1._id: "1" }
# Document 2: matches "burgers" with weight 0.9
# RRF score = 1/(60+1) * 0.9 = 0.01475
- close_to: {hits.hits.0._score: {value: 0.01475, error: 0.0001}}
# Document 1: matches "pizza" with weight 0.1
# RRF score = 1/(60+1) * 0.1 = 0.00164
- close_to: {hits.hits.1._score: {value: 0.00164, error: 0.0001}}

---
"Weighted RRF retriever errors on negative weight":
- do:
Expand All @@ -104,4 +135,5 @@ setup:
description: "pizza"
weight: 0.7
- match: { error.type: "x_content_parse_exception" }
- contains: { error.reason: "failed to parse field" }
- contains: { error.caused_by.reason: "[weight] must be non-negative, found [-0.5]" }