Skip to content

Commit fc0ea64

Browse files
mridula-s109elasticsearchmachineioanatialeemthompo
authored
Implement support for weighted rrf (#130658)
* RRFRetrieverComponent added: * Modified parser, toXContent and included component in the RetrieverBuilder * [CI] Auto commit changes from spotless * Resolved merge conflicts * Fixed compile issues in tests * [CI] Auto commit changes from spotless * trying to resolve parse errors * wip * Modified builder * [CI] Auto commit changes from spotless * Removed unnecessary code * Fixed import * Enhanced tests * Fixed the failing tests * Yaml tests were added * Added cluster features to it * Fixed spotless * Update docs/changelog/130658.yaml * Fixed the relaxed constraints * Resolving issues * Resolved PR comments * removed simplified rrf * changed the test file back to its original state * Resolved comments to have a helper method and the test case to use it * made parsing robust * IT test reverted * Replaced the declareString array parser * Enforced weights as nonnull * Fixed the weights null * Empty weight shouldn't be serialised * [CI] Auto commit changes from spotless * removed the hard coding * Cleanup and optimised the code flow * Fixed the comments * [CI] Auto commit changes from spotless * optimised test * Added additional test * addressed the comments * Update docs/changelog/130658.yaml Co-authored-by: Liam Thompson <[email protected]> * Explicit check for retriever object * Resolved PR comments * Fixed the error message --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Ioana Tagirta <[email protected]> Co-authored-by: Liam Thompson <[email protected]>
1 parent 7e2ff36 commit fc0ea64

File tree

7 files changed

+663
-69
lines changed

7 files changed

+663
-69
lines changed

docs/changelog/130658.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 130658
2+
summary: Add support for weighted RRF in retrievers
3+
area: Relevance
4+
type: enhancement
5+
issues: []

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ public Set<NodeFeature> getTestFeatures() {
3636
LINEAR_RETRIEVER_L2_NORM,
3737
LINEAR_RETRIEVER_MINSCORE_FIX,
3838
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
39-
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
39+
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
40+
RRFRetrieverBuilder.WEIGHTED_SUPPORT
4041
);
4142
}
4243
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 81 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.search.rank.RankBuilder;
2121
import org.elasticsearch.search.rank.RankDoc;
2222
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
23+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource;
2324
import org.elasticsearch.search.retriever.RetrieverBuilder;
2425
import org.elasticsearch.search.retriever.RetrieverParserContext;
2526
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
@@ -37,7 +38,7 @@
3738
import java.util.Map;
3839
import java.util.Objects;
3940

40-
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
41+
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;
4142

4243
/**
4344
* An rrf retriever is used to represent an rrf rank element, but
@@ -48,6 +49,7 @@
4849
*/
4950
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
5051
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
52+
public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");
5153

5254
public static final String NAME = "rrf";
5355

@@ -57,37 +59,38 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5759
public static final ParseField QUERY_FIELD = new ParseField("query");
5860

5961
public static final int DEFAULT_RANK_CONSTANT = 60;
62+
63+
private final float[] weights;
64+
6065
@SuppressWarnings("unchecked")
6166
static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
6267
NAME,
6368
false,
6469
args -> {
65-
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
70+
List<RRFRetrieverComponent> retrieverComponents = args[0] == null ? List.of() : (List<RRFRetrieverComponent>) args[0];
6671
List<String> fields = (List<String>) args[1];
6772
String query = (String) args[2];
6873
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
6974
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
7075

71-
List<RetrieverSource> innerRetrievers = childRetrievers != null
72-
? childRetrievers.stream().map(RetrieverSource::from).toList()
73-
: List.of();
74-
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
76+
int n = retrieverComponents.size();
77+
List<RetrieverSource> innerRetrievers = new ArrayList<>(n);
78+
float[] weights = new float[n];
79+
for (int i = 0; i < n; i++) {
80+
RRFRetrieverComponent component = retrieverComponents.get(i);
81+
innerRetrievers.add(RetrieverSource.from(component.retriever()));
82+
weights[i] = component.weight();
83+
}
84+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
7585
}
7686
);
7787

7888
static {
79-
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
80-
p.nextToken();
81-
String name = p.currentName();
82-
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
83-
c.trackRetrieverUsage(retrieverBuilder.getName());
84-
p.nextToken();
85-
return retrieverBuilder;
86-
}, RETRIEVERS_FIELD);
87-
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
88-
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
89-
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
90-
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
89+
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), RRFRetrieverComponent::fromXContent, RETRIEVERS_FIELD);
90+
PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
91+
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
92+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
93+
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
9194
RetrieverBuilder.declareBaseParserFields(PARSER);
9295
}
9396

@@ -103,27 +106,46 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
103106
private final int rankConstant;
104107

105108
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
106-
this(childRetrievers, null, null, rankWindowSize, rankConstant);
109+
this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers));
110+
}
111+
112+
private static float[] createDefaultWeights(List<?> retrievers) {
113+
int size = retrievers == null ? 0 : retrievers.size();
114+
float[] defaultWeights = new float[size];
115+
Arrays.fill(defaultWeights, DEFAULT_WEIGHT);
116+
return defaultWeights;
107117
}
108118

109119
public RRFRetrieverBuilder(
110120
List<RetrieverSource> childRetrievers,
111121
List<String> fields,
112122
String query,
113123
int rankWindowSize,
114-
int rankConstant
124+
int rankConstant,
125+
float[] weights
115126
) {
116127
// Use a mutable list for childRetrievers so that we can use addChild
117128
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
118129
this.fields = fields == null ? null : List.copyOf(fields);
119130
this.query = query;
120131
this.rankConstant = rankConstant;
132+
Objects.requireNonNull(weights, "weights must not be null");
133+
if (weights.length != innerRetrievers.size()) {
134+
throw new IllegalArgumentException(
135+
"weights array length [" + weights.length + "] must match retrievers count [" + innerRetrievers.size() + "]"
136+
);
137+
}
138+
this.weights = weights;
121139
}
122140

123141
public int rankConstant() {
124142
return rankConstant;
125143
}
126144

145+
public float[] weights() {
146+
return weights;
147+
}
148+
127149
@Override
128150
public String getName() {
129151
return NAME;
@@ -137,6 +159,7 @@ public ActionRequestValidationException validate(
137159
boolean allowPartialSearchResults
138160
) {
139161
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
162+
140163
return MultiFieldsInnerRetrieverUtils.validateParams(
141164
innerRetrievers,
142165
fields,
@@ -151,7 +174,14 @@ public ActionRequestValidationException validate(
151174

152175
@Override
153176
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
154-
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
177+
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(
178+
newRetrievers,
179+
this.fields,
180+
this.query,
181+
this.rankWindowSize,
182+
this.rankConstant,
183+
this.weights
184+
);
155185
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
156186
clone.retrieverName = retrieverName;
157187
return clone;
@@ -183,7 +213,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
183213

184214
// calculate the current rrf score for this document
185215
// later used to sort and convert to a rank
186-
value.score += 1.0f / (rankConstant + frank);
216+
value.score += this.weights[findex] * (1.0f / (rankConstant + frank));
187217

188218
if (explain && value.positions != null && value.scores != null) {
189219
// record the position for each query
@@ -238,10 +268,14 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
238268
query,
239269
localIndicesMetadata.values(),
240270
r -> {
241-
List<RetrieverSource> retrievers = r.stream()
242-
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
243-
.toList();
244-
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
271+
List<RetrieverSource> retrievers = new ArrayList<>(r.size());
272+
float[] weights = new float[r.size()];
273+
for (int i = 0; i < r.size(); i++) {
274+
var retriever = r.get(i);
275+
retrievers.add(retriever.retrieverSource());
276+
weights[i] = retriever.weight();
277+
}
278+
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
245279
},
246280
w -> {
247281
if (w != 1.0f) {
@@ -255,7 +289,8 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
255289
if (fieldsInnerRetrievers.isEmpty() == false) {
256290
// TODO: This is an incomplete solution as it does not address other incomplete copy issues
257291
// (such as dropping the retriever name and min score)
258-
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
292+
float[] weights = createDefaultWeights(fieldsInnerRetrievers);
293+
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, null, null, rankWindowSize, rankConstant, weights);
259294
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
260295
} else {
261296
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
@@ -266,29 +301,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
266301
return rewritten;
267302
}
268303

269-
// ---- FOR TESTING XCONTENT PARSING ----
270-
271-
@Override
272-
public boolean doEquals(Object o) {
273-
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
274-
return super.doEquals(o)
275-
&& Objects.equals(fields, that.fields)
276-
&& Objects.equals(query, that.query)
277-
&& rankConstant == that.rankConstant;
278-
}
279-
280-
@Override
281-
public int doHashCode() {
282-
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
283-
}
284-
285304
@Override
286305
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
287306
if (innerRetrievers.isEmpty() == false) {
288307
builder.startArray(RETRIEVERS_FIELD.getPreferredName());
289-
290-
for (var entry : innerRetrievers) {
291-
entry.retriever().toXContent(builder, params);
308+
for (int i = 0; i < innerRetrievers.size(); i++) {
309+
RRFRetrieverComponent component = new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), weights[i]);
310+
component.toXContent(builder, params);
292311
}
293312
builder.endArray();
294313
}
@@ -307,4 +326,20 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
307326
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
308327
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
309328
}
329+
330+
// ---- FOR TESTING XCONTENT PARSING ----
331+
@Override
332+
public boolean doEquals(Object o) {
333+
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
334+
return super.doEquals(o)
335+
&& Objects.equals(fields, that.fields)
336+
&& Objects.equals(query, that.query)
337+
&& rankConstant == that.rankConstant
338+
&& Arrays.equals(weights, that.weights);
339+
}
340+
341+
@Override
342+
public int doHashCode() {
343+
return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
344+
}
310345
}
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.rank.rrf;
9+
10+
import org.elasticsearch.common.ParsingException;
11+
import org.elasticsearch.search.retriever.RetrieverBuilder;
12+
import org.elasticsearch.search.retriever.RetrieverParserContext;
13+
import org.elasticsearch.xcontent.ParseField;
14+
import org.elasticsearch.xcontent.ToXContentObject;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
import org.elasticsearch.xcontent.XContentParser;
17+
18+
import java.io.IOException;
19+
import java.util.Objects;
20+
21+
public class RRFRetrieverComponent implements ToXContentObject {
22+
23+
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
24+
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
25+
static final float DEFAULT_WEIGHT = 1f;
26+
27+
final RetrieverBuilder retriever;
28+
final float weight;
29+
30+
public RRFRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight) {
31+
this.retriever = Objects.requireNonNull(retrieverBuilder, "retrieverBuilder must not be null");
32+
this.weight = weight == null ? DEFAULT_WEIGHT : weight;
33+
if (this.weight < 0) {
34+
throw new IllegalArgumentException("[weight] must be non-negative, found [" + this.weight + "]");
35+
}
36+
}
37+
38+
public RetrieverBuilder retriever() {
39+
return retriever;
40+
}
41+
42+
public float weight() {
43+
return weight;
44+
}
45+
46+
@Override
47+
public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException {
48+
builder.startObject();
49+
builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
50+
builder.field(WEIGHT_FIELD.getPreferredName(), weight);
51+
builder.endObject();
52+
return builder;
53+
}
54+
55+
public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
56+
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
57+
throw new ParsingException(parser.getTokenLocation(), "expected object but found [{}]", parser.currentToken());
58+
}
59+
60+
// Peek at the first field to determine the format
61+
XContentParser.Token token = parser.nextToken();
62+
if (token == XContentParser.Token.END_OBJECT) {
63+
throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
64+
}
65+
if (token != XContentParser.Token.FIELD_NAME) {
66+
throw new ParsingException(parser.getTokenLocation(), "expected field name but found [{}]", token);
67+
}
68+
69+
String firstFieldName = parser.currentName();
70+
71+
// Check if this is a structured component (starts with "retriever" or "weight")
72+
if (RETRIEVER_FIELD.match(firstFieldName, parser.getDeprecationHandler())
73+
|| WEIGHT_FIELD.match(firstFieldName, parser.getDeprecationHandler())) {
74+
// This is a structured component - parse manually
75+
RetrieverBuilder retriever = null;
76+
Float weight = null;
77+
78+
do {
79+
String fieldName = parser.currentName();
80+
if (RETRIEVER_FIELD.match(fieldName, parser.getDeprecationHandler())) {
81+
if (retriever != null) {
82+
throw new ParsingException(parser.getTokenLocation(), "only one retriever can be specified");
83+
}
84+
parser.nextToken();
85+
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
86+
throw new ParsingException(parser.getTokenLocation(), "retriever must be an object");
87+
}
88+
parser.nextToken();
89+
String retrieverType = parser.currentName();
90+
retriever = parser.namedObject(RetrieverBuilder.class, retrieverType, context);
91+
context.trackRetrieverUsage(retriever.getName());
92+
parser.nextToken();
93+
} else if (WEIGHT_FIELD.match(fieldName, parser.getDeprecationHandler())) {
94+
if (weight != null) {
95+
throw new ParsingException(parser.getTokenLocation(), "[weight] field can only be specified once");
96+
}
97+
parser.nextToken();
98+
weight = parser.floatValue();
99+
} else {
100+
throw new ParsingException(
101+
parser.getTokenLocation(),
102+
"unknown field [{}], expected [{}] or [{}]",
103+
fieldName,
104+
RETRIEVER_FIELD.getPreferredName(),
105+
WEIGHT_FIELD.getPreferredName()
106+
);
107+
}
108+
} while (parser.nextToken() == XContentParser.Token.FIELD_NAME);
109+
110+
if (retriever == null) {
111+
throw new ParsingException(parser.getTokenLocation(), "retriever component must contain a retriever");
112+
}
113+
114+
return new RRFRetrieverComponent(retriever, weight);
115+
} else {
116+
RetrieverBuilder retriever = parser.namedObject(RetrieverBuilder.class, firstFieldName, context);
117+
context.trackRetrieverUsage(retriever.getName());
118+
if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
119+
throw new ParsingException(parser.getTokenLocation(), "unknown field [{}] after retriever", parser.currentName());
120+
}
121+
return new RRFRetrieverComponent(retriever, DEFAULT_WEIGHT);
122+
}
123+
}
124+
}

0 commit comments

Comments
 (0)