Skip to content

Commit f4e5d85

Browse files
authored
Simplified RRF Retriever (#129659) (#129869)
(cherry picked from commit 636da86) # Conflicts: # x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java # x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java
1 parent 5a17180 commit f4e5d85

File tree

8 files changed

+814
-41
lines changed

8 files changed

+814
-41
lines changed

docs/changelog/129659.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129659
2+
summary: Simplified RRF Retriever
3+
area: Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ public Set<NodeFeature> getTestFeatures() {
3939
LINEAR_RETRIEVER_MINMAX_SINGLE_DOC_FIX,
4040
LINEAR_RETRIEVER_L2_NORM,
4141
LINEAR_RETRIEVER_MINSCORE_FIX,
42-
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
42+
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
43+
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT
4344
);
4445
}
4546
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,28 @@
88
package org.elasticsearch.xpack.rank.rrf;
99

1010
import org.apache.lucene.search.ScoreDoc;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ResolvedIndices;
1113
import org.elasticsearch.common.ParsingException;
1214
import org.elasticsearch.common.util.Maps;
1315
import org.elasticsearch.features.NodeFeature;
16+
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
1417
import org.elasticsearch.index.query.QueryBuilder;
18+
import org.elasticsearch.index.query.QueryRewriteContext;
1519
import org.elasticsearch.license.LicenseUtils;
20+
import org.elasticsearch.search.builder.SearchSourceBuilder;
1621
import org.elasticsearch.search.rank.RankBuilder;
1722
import org.elasticsearch.search.rank.RankDoc;
1823
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1924
import org.elasticsearch.search.retriever.RetrieverBuilder;
2025
import org.elasticsearch.search.retriever.RetrieverParserContext;
26+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
2127
import org.elasticsearch.xcontent.ConstructingObjectParser;
2228
import org.elasticsearch.xcontent.ParseField;
2329
import org.elasticsearch.xcontent.XContentBuilder;
2430
import org.elasticsearch.xcontent.XContentParser;
2531
import org.elasticsearch.xpack.core.XPackPlugin;
32+
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
2633

2734
import java.io.IOException;
2835
import java.util.ArrayList;
@@ -31,7 +38,6 @@
3138
import java.util.Map;
3239
import java.util.Objects;
3340

34-
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3541
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3642

3743
/**
@@ -42,13 +48,16 @@
4248
* formula.
4349
*/
4450
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
51+
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
4552

4653
public static final String NAME = "rrf";
4754
public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature("rrf_retriever_supported", true);
4855
public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported", true);
4956

5057
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
5158
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
59+
public static final ParseField FIELDS_FIELD = new ParseField("fields");
60+
public static final ParseField QUERY_FIELD = new ParseField("query");
5261

5362
public static final int DEFAULT_RANK_CONSTANT = 60;
5463
@SuppressWarnings("unchecked")
@@ -57,22 +66,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5766
false,
5867
args -> {
5968
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
60-
List<RetrieverSource> innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList();
61-
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
62-
int rankConstant = args[2] == null ? DEFAULT_RANK_CONSTANT : (int) args[2];
63-
return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant);
69+
List<String> fields = (List<String>) args[1];
70+
String query = (String) args[2];
71+
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
72+
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
73+
74+
List<RetrieverSource> innerRetrievers = childRetrievers != null
75+
? childRetrievers.stream().map(RetrieverSource::from).toList()
76+
: List.of();
77+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
6478
}
6579
);
6680

6781
static {
68-
PARSER.declareObjectArray(constructorArg(), (p, c) -> {
82+
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
6983
p.nextToken();
7084
String name = p.currentName();
7185
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
7286
c.trackRetrieverUsage(retrieverBuilder.getName());
7387
p.nextToken();
7488
return retrieverBuilder;
7589
}, RETRIEVERS_FIELD);
90+
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
91+
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
7692
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
7793
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
7894
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
@@ -91,25 +107,60 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
91107
return PARSER.apply(parser, context);
92108
}
93109

110+
private final List<String> fields;
111+
private final String query;
94112
private final int rankConstant;
95113

96-
public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
97-
this(new ArrayList<>(), rankWindowSize, rankConstant);
114+
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
115+
this(childRetrievers, null, null, rankWindowSize, rankConstant);
98116
}
99117

100-
RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
101-
super(childRetrievers, rankWindowSize);
118+
public RRFRetrieverBuilder(
119+
List<RetrieverSource> childRetrievers,
120+
List<String> fields,
121+
String query,
122+
int rankWindowSize,
123+
int rankConstant
124+
) {
125+
// Use a mutable list for childRetrievers so that we can use addChild
126+
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
127+
this.fields = fields == null ? List.of() : List.copyOf(fields);
128+
this.query = query;
102129
this.rankConstant = rankConstant;
103130
}
104131

132+
public int rankConstant() {
133+
return rankConstant;
134+
}
135+
105136
@Override
106137
public String getName() {
107138
return NAME;
108139
}
109140

141+
@Override
142+
public ActionRequestValidationException validate(
143+
SearchSourceBuilder source,
144+
ActionRequestValidationException validationException,
145+
boolean isScroll,
146+
boolean allowPartialSearchResults
147+
) {
148+
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
149+
return MultiFieldsInnerRetrieverUtils.validateParams(
150+
innerRetrievers,
151+
fields,
152+
query,
153+
getName(),
154+
RETRIEVERS_FIELD.getPreferredName(),
155+
FIELDS_FIELD.getPreferredName(),
156+
QUERY_FIELD.getPreferredName(),
157+
validationException
158+
);
159+
}
160+
110161
@Override
111162
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
112-
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
163+
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
113164
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
114165
clone.retrieverName = retrieverName;
115166
return clone;
@@ -172,17 +223,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
172223
return topResults;
173224
}
174225

226+
@Override
227+
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
228+
RetrieverBuilder rewritten = this;
229+
230+
ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
231+
if (resolvedIndices != null && query != null) {
232+
// TODO: Refactor duplicate code
233+
// Using the multi-fields query format
234+
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
235+
if (localIndicesMetadata.size() > 1) {
236+
throw new IllegalArgumentException(
237+
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices"
238+
);
239+
} else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
240+
throw new IllegalArgumentException(
241+
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices"
242+
);
243+
}
244+
245+
List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
246+
fields,
247+
query,
248+
localIndicesMetadata.values(),
249+
r -> {
250+
List<RetrieverSource> retrievers = r.stream()
251+
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
252+
.toList();
253+
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
254+
},
255+
w -> {
256+
if (w != 1.0f) {
257+
throw new IllegalArgumentException(
258+
"[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
259+
);
260+
}
261+
}
262+
).stream().map(RetrieverSource::from).toList();
263+
264+
if (fieldsInnerRetrievers.isEmpty() == false) {
265+
// TODO: This is an incomplete solution as it does not address other incomplete copy issues
266+
// (such as dropping the retriever name and min score)
267+
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
268+
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
269+
} else {
270+
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
271+
rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
272+
}
273+
}
274+
275+
return rewritten;
276+
}
277+
175278
// ---- FOR TESTING XCONTENT PARSING ----
176279

177280
@Override
178281
public boolean doEquals(Object o) {
179282
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
180-
return super.doEquals(o) && rankConstant == that.rankConstant;
283+
return super.doEquals(o)
284+
&& Objects.equals(fields, that.fields)
285+
&& Objects.equals(query, that.query)
286+
&& rankConstant == that.rankConstant;
181287
}
182288

183289
@Override
184290
public int doHashCode() {
185-
return Objects.hash(super.doHashCode(), rankConstant);
291+
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
186292
}
187293

188294
@Override
@@ -196,6 +302,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
196302
builder.endArray();
197303
}
198304

305+
if (fields.isEmpty() == false) {
306+
builder.startArray(FIELDS_FIELD.getPreferredName());
307+
for (String field : fields) {
308+
builder.value(field);
309+
}
310+
builder.endArray();
311+
}
312+
if (query != null) {
313+
builder.field(QUERY_FIELD.getPreferredName(), query);
314+
}
315+
199316
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
200317
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
201318
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.search.SearchRequest;
1111
import org.elasticsearch.common.Strings;
1212
import org.elasticsearch.search.builder.SearchSourceBuilder;
13+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1314
import org.elasticsearch.search.retriever.RetrieverBuilder;
1415
import org.elasticsearch.search.retriever.RetrieverParserContext;
1516
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
@@ -45,13 +46,22 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
4546
if (randomBoolean()) {
4647
rankConstant = randomIntBetween(1, 1000000);
4748
}
48-
var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant);
49+
50+
List<String> fields = null;
51+
String query = null;
52+
if (randomBoolean()) {
53+
fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10));
54+
query = randomAlphaOfLengthBetween(1, 10);
55+
}
56+
4957
int retrieverCount = randomIntBetween(2, 50);
58+
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
5059
while (retrieverCount > 0) {
51-
ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder());
60+
innerRetrievers.add(CompoundRetrieverBuilder.RetrieverSource.from(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
5261
--retrieverCount;
5362
}
54-
return ret;
63+
64+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
5565
}
5666

5767
@Override
@@ -94,28 +104,32 @@ protected NamedXContentRegistry xContentRegistry() {
94104
}
95105

96106
public void testRRFRetrieverParsing() throws IOException {
97-
String restContent = "{"
98-
+ " \"retriever\": {"
99-
+ " \"rrf\": {"
100-
+ " \"retrievers\": ["
101-
+ " {"
102-
+ " \"test\": {"
103-
+ " \"value\": \"foo\""
104-
+ " }"
105-
+ " },"
106-
+ " {"
107-
+ " \"test\": {"
108-
+ " \"value\": \"bar\""
109-
+ " }"
110-
+ " }"
111-
+ " ],"
112-
+ " \"rank_window_size\": 100,"
113-
+ " \"rank_constant\": 10,"
114-
+ " \"min_score\": 20.0,"
115-
+ " \"_name\": \"foo_rrf\""
116-
+ " }"
117-
+ " }"
118-
+ "}";
107+
String restContent = """
108+
{
109+
"retriever": {
110+
"rrf": {
111+
"retrievers": [
112+
{
113+
"test": {
114+
"value": "foo"
115+
}
116+
},
117+
{
118+
"test": {
119+
"value": "bar"
120+
}
121+
}
122+
],
123+
"fields": ["field1", "field2"],
124+
"query": "baz",
125+
"rank_window_size": 100,
126+
"rank_constant": 10,
127+
"min_score": 20.0,
128+
"_name": "foo_rrf"
129+
}
130+
}
131+
}
132+
""";
119133
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
120134
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
121135
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);

0 commit comments

Comments
 (0)