Skip to content

Commit 636da86

Browse files
authored
Simplified RRF Retriever (#129659)
1 parent f430a6c commit 636da86

File tree

8 files changed

+816
-41
lines changed

8 files changed

+816
-41
lines changed

docs/changelog/129659.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129659
2+
summary: Simplified RRF Retriever
3+
area: Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.features.FeatureSpecification;
1111
import org.elasticsearch.features.NodeFeature;
1212
import org.elasticsearch.xpack.rank.linear.LinearRetrieverBuilder;
13+
import org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder;
1314

1415
import java.util.Set;
1516

@@ -34,7 +35,8 @@ public Set<NodeFeature> getTestFeatures() {
3435
LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX,
3536
LINEAR_RETRIEVER_L2_NORM,
3637
LINEAR_RETRIEVER_MINSCORE_FIX,
37-
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
38+
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
39+
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
3840
);
3941
}
4042
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 131 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,27 @@
88
package org.elasticsearch.xpack.rank.rrf;
99

1010
import org.apache.lucene.search.ScoreDoc;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ResolvedIndices;
1113
import org.elasticsearch.common.util.Maps;
14+
import org.elasticsearch.features.NodeFeature;
15+
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
1216
import org.elasticsearch.index.query.QueryBuilder;
17+
import org.elasticsearch.index.query.QueryRewriteContext;
1318
import org.elasticsearch.license.LicenseUtils;
19+
import org.elasticsearch.search.builder.SearchSourceBuilder;
1420
import org.elasticsearch.search.rank.RankBuilder;
1521
import org.elasticsearch.search.rank.RankDoc;
1622
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1723
import org.elasticsearch.search.retriever.RetrieverBuilder;
1824
import org.elasticsearch.search.retriever.RetrieverParserContext;
25+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
1926
import org.elasticsearch.xcontent.ConstructingObjectParser;
2027
import org.elasticsearch.xcontent.ParseField;
2128
import org.elasticsearch.xcontent.XContentBuilder;
2229
import org.elasticsearch.xcontent.XContentParser;
2330
import org.elasticsearch.xpack.core.XPackPlugin;
31+
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
2432

2533
import java.io.IOException;
2634
import java.util.ArrayList;
@@ -29,7 +37,6 @@
2937
import java.util.Map;
3038
import java.util.Objects;
3139

32-
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3340
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3441

3542
/**
@@ -40,11 +47,14 @@
4047
* formula.
4148
*/
4249
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
50+
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
4351

4452
public static final String NAME = "rrf";
4553

4654
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
4755
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
56+
public static final ParseField FIELDS_FIELD = new ParseField("fields");
57+
public static final ParseField QUERY_FIELD = new ParseField("query");
4858

4959
public static final int DEFAULT_RANK_CONSTANT = 60;
5060
@SuppressWarnings("unchecked")
@@ -53,22 +63,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5363
false,
5464
args -> {
5565
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
56-
List<RetrieverSource> innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList();
57-
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
58-
int rankConstant = args[2] == null ? DEFAULT_RANK_CONSTANT : (int) args[2];
59-
return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant);
66+
List<String> fields = (List<String>) args[1];
67+
String query = (String) args[2];
68+
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
69+
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
70+
71+
List<RetrieverSource> innerRetrievers = childRetrievers != null
72+
? childRetrievers.stream().map(RetrieverSource::from).toList()
73+
: List.of();
74+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
6075
}
6176
);
6277

6378
static {
64-
PARSER.declareObjectArray(constructorArg(), (p, c) -> {
79+
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
6580
p.nextToken();
6681
String name = p.currentName();
6782
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
6883
c.trackRetrieverUsage(retrieverBuilder.getName());
6984
p.nextToken();
7085
return retrieverBuilder;
7186
}, RETRIEVERS_FIELD);
87+
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
88+
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
7289
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
7390
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
7491
RetrieverBuilder.declareBaseParserFields(PARSER);
@@ -81,25 +98,60 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
8198
return PARSER.apply(parser, context);
8299
}
83100

101+
private final List<String> fields;
102+
private final String query;
84103
private final int rankConstant;
85104

86-
public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
87-
this(new ArrayList<>(), rankWindowSize, rankConstant);
105+
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
106+
this(childRetrievers, null, null, rankWindowSize, rankConstant);
88107
}
89108

90-
RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
91-
super(childRetrievers, rankWindowSize);
109+
public RRFRetrieverBuilder(
110+
List<RetrieverSource> childRetrievers,
111+
List<String> fields,
112+
String query,
113+
int rankWindowSize,
114+
int rankConstant
115+
) {
116+
// Use a mutable list for childRetrievers so that we can use addChild
117+
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
118+
this.fields = fields == null ? List.of() : List.copyOf(fields);
119+
this.query = query;
92120
this.rankConstant = rankConstant;
93121
}
94122

123+
public int rankConstant() {
124+
return rankConstant;
125+
}
126+
95127
@Override
96128
public String getName() {
97129
return NAME;
98130
}
99131

132+
@Override
133+
public ActionRequestValidationException validate(
134+
SearchSourceBuilder source,
135+
ActionRequestValidationException validationException,
136+
boolean isScroll,
137+
boolean allowPartialSearchResults
138+
) {
139+
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
140+
return MultiFieldsInnerRetrieverUtils.validateParams(
141+
innerRetrievers,
142+
fields,
143+
query,
144+
getName(),
145+
RETRIEVERS_FIELD.getPreferredName(),
146+
FIELDS_FIELD.getPreferredName(),
147+
QUERY_FIELD.getPreferredName(),
148+
validationException
149+
);
150+
}
151+
100152
@Override
101153
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
102-
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
154+
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
103155
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
104156
clone.retrieverName = retrieverName;
105157
return clone;
@@ -162,17 +214,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
162214
return topResults;
163215
}
164216

217+
@Override
218+
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
219+
RetrieverBuilder rewritten = this;
220+
221+
ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
222+
if (resolvedIndices != null && query != null) {
223+
// TODO: Refactor duplicate code
224+
// Using the multi-fields query format
225+
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
226+
if (localIndicesMetadata.size() > 1) {
227+
throw new IllegalArgumentException(
228+
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices"
229+
);
230+
} else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
231+
throw new IllegalArgumentException(
232+
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices"
233+
);
234+
}
235+
236+
List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
237+
fields,
238+
query,
239+
localIndicesMetadata.values(),
240+
r -> {
241+
List<RetrieverSource> retrievers = r.stream()
242+
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
243+
.toList();
244+
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
245+
},
246+
w -> {
247+
if (w != 1.0f) {
248+
throw new IllegalArgumentException(
249+
"[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
250+
);
251+
}
252+
}
253+
).stream().map(RetrieverSource::from).toList();
254+
255+
if (fieldsInnerRetrievers.isEmpty() == false) {
256+
// TODO: This is an incomplete solution as it does not address other incomplete copy issues
257+
// (such as dropping the retriever name and min score)
258+
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
259+
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
260+
} else {
261+
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
262+
rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
263+
}
264+
}
265+
266+
return rewritten;
267+
}
268+
165269
// ---- FOR TESTING XCONTENT PARSING ----
166270

167271
@Override
168272
public boolean doEquals(Object o) {
169273
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
170-
return super.doEquals(o) && rankConstant == that.rankConstant;
274+
return super.doEquals(o)
275+
&& Objects.equals(fields, that.fields)
276+
&& Objects.equals(query, that.query)
277+
&& rankConstant == that.rankConstant;
171278
}
172279

173280
@Override
174281
public int doHashCode() {
175-
return Objects.hash(super.doHashCode(), rankConstant);
282+
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
176283
}
177284

178285
@Override
@@ -186,6 +293,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
186293
builder.endArray();
187294
}
188295

296+
if (fields.isEmpty() == false) {
297+
builder.startArray(FIELDS_FIELD.getPreferredName());
298+
for (String field : fields) {
299+
builder.value(field);
300+
}
301+
builder.endArray();
302+
}
303+
if (query != null) {
304+
builder.field(QUERY_FIELD.getPreferredName(), query);
305+
}
306+
189307
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
190308
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
191309
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.search.SearchRequest;
1111
import org.elasticsearch.common.Strings;
1212
import org.elasticsearch.search.builder.SearchSourceBuilder;
13+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1314
import org.elasticsearch.search.retriever.RetrieverBuilder;
1415
import org.elasticsearch.search.retriever.RetrieverParserContext;
1516
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
@@ -45,13 +46,22 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
4546
if (randomBoolean()) {
4647
rankConstant = randomIntBetween(1, 1000000);
4748
}
48-
var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant);
49+
50+
List<String> fields = null;
51+
String query = null;
52+
if (randomBoolean()) {
53+
fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10));
54+
query = randomAlphaOfLengthBetween(1, 10);
55+
}
56+
4957
int retrieverCount = randomIntBetween(2, 50);
58+
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
5059
while (retrieverCount > 0) {
51-
ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder());
60+
innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
5261
--retrieverCount;
5362
}
54-
return ret;
63+
64+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
5565
}
5666

5767
@Override
@@ -94,28 +104,32 @@ protected NamedXContentRegistry xContentRegistry() {
94104
}
95105

96106
public void testRRFRetrieverParsing() throws IOException {
97-
String restContent = "{"
98-
+ " \"retriever\": {"
99-
+ " \"rrf\": {"
100-
+ " \"retrievers\": ["
101-
+ " {"
102-
+ " \"test\": {"
103-
+ " \"value\": \"foo\""
104-
+ " }"
105-
+ " },"
106-
+ " {"
107-
+ " \"test\": {"
108-
+ " \"value\": \"bar\""
109-
+ " }"
110-
+ " }"
111-
+ " ],"
112-
+ " \"rank_window_size\": 100,"
113-
+ " \"rank_constant\": 10,"
114-
+ " \"min_score\": 20.0,"
115-
+ " \"_name\": \"foo_rrf\""
116-
+ " }"
117-
+ " }"
118-
+ "}";
107+
String restContent = """
108+
{
109+
"retriever": {
110+
"rrf": {
111+
"retrievers": [
112+
{
113+
"test": {
114+
"value": "foo"
115+
}
116+
},
117+
{
118+
"test": {
119+
"value": "bar"
120+
}
121+
}
122+
],
123+
"fields": ["field1", "field2"],
124+
"query": "baz",
125+
"rank_window_size": 100,
126+
"rank_constant": 10,
127+
"min_score": 20.0,
128+
"_name": "foo_rrf"
129+
}
130+
}
131+
}
132+
""";
119133
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
120134
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
121135
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);

0 commit comments

Comments
 (0)