Skip to content

Commit 34040c9

Browse files
committed
Apply RRF retriever changes from elastic#128633
1 parent 4275bc7 commit 34040c9

File tree

3 files changed

+379
-40
lines changed

3 files changed

+379
-40
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 128 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,26 @@
88
package org.elasticsearch.xpack.rank.rrf;
99

1010
import org.apache.lucene.search.ScoreDoc;
11+
import org.elasticsearch.action.ActionRequestValidationException;
12+
import org.elasticsearch.action.ResolvedIndices;
1113
import org.elasticsearch.common.util.Maps;
14+
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
1215
import org.elasticsearch.index.query.QueryBuilder;
16+
import org.elasticsearch.index.query.QueryRewriteContext;
1317
import org.elasticsearch.license.LicenseUtils;
18+
import org.elasticsearch.search.builder.SearchSourceBuilder;
1419
import org.elasticsearch.search.rank.RankBuilder;
1520
import org.elasticsearch.search.rank.RankDoc;
1621
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1722
import org.elasticsearch.search.retriever.RetrieverBuilder;
1823
import org.elasticsearch.search.retriever.RetrieverParserContext;
24+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
1925
import org.elasticsearch.xcontent.ConstructingObjectParser;
2026
import org.elasticsearch.xcontent.ParseField;
2127
import org.elasticsearch.xcontent.XContentBuilder;
2228
import org.elasticsearch.xcontent.XContentParser;
2329
import org.elasticsearch.xpack.core.XPackPlugin;
30+
import org.elasticsearch.xpack.rank.simplified.SimplifiedInnerRetrieverUtils;
2431

2532
import java.io.IOException;
2633
import java.util.ArrayList;
@@ -29,7 +36,6 @@
2936
import java.util.Map;
3037
import java.util.Objects;
3138

32-
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3339
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3440

3541
/**
@@ -45,6 +51,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
4551

4652
public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
4753
public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
54+
public static final ParseField FIELDS_FIELD = new ParseField("fields");
55+
public static final ParseField QUERY_FIELD = new ParseField("query");
4856

4957
public static final int DEFAULT_RANK_CONSTANT = 60;
5058
@SuppressWarnings("unchecked")
@@ -53,22 +61,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5361
false,
5462
args -> {
5563
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
56-
List<RetrieverSource> innerRetrievers = childRetrievers.stream().map(RetrieverSource::from).toList();
57-
int rankWindowSize = args[1] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[1];
58-
int rankConstant = args[2] == null ? DEFAULT_RANK_CONSTANT : (int) args[2];
59-
return new RRFRetrieverBuilder(innerRetrievers, rankWindowSize, rankConstant);
64+
List<String> fields = (List<String>) args[1];
65+
String query = (String) args[2];
66+
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
67+
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
68+
69+
List<RetrieverSource> innerRetrievers = childRetrievers != null
70+
? childRetrievers.stream().map(r -> new RetrieverSource(r, null)).toList()
71+
: List.of();
72+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
6073
}
6174
);
6275

6376
static {
64-
PARSER.declareObjectArray(constructorArg(), (p, c) -> {
77+
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
6578
p.nextToken();
6679
String name = p.currentName();
6780
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
6881
c.trackRetrieverUsage(retrieverBuilder.getName());
6982
p.nextToken();
7083
return retrieverBuilder;
7184
}, RETRIEVERS_FIELD);
85+
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
86+
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
7287
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
7388
PARSER.declareInt(optionalConstructorArg(), RANK_CONSTANT_FIELD);
7489
RetrieverBuilder.declareBaseParserFields(PARSER);
@@ -81,25 +96,63 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
8196
return PARSER.apply(parser, context);
8297
}
8398

99+
private final List<String> fields;
100+
private final String query;
84101
private final int rankConstant;
85102

86-
public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
87-
this(new ArrayList<>(), rankWindowSize, rankConstant);
103+
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
104+
this(childRetrievers, null, null, rankWindowSize, rankConstant);
88105
}
89106

90-
RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
91-
super(childRetrievers, rankWindowSize);
107+
public RRFRetrieverBuilder(
108+
List<RetrieverSource> childRetrievers,
109+
List<String> fields,
110+
String query,
111+
int rankWindowSize,
112+
int rankConstant
113+
) {
114+
// Use a mutable list for childRetrievers so that we can use addChild
115+
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
116+
this.fields = fields == null ? List.of() : List.copyOf(fields);
117+
this.query = query;
92118
this.rankConstant = rankConstant;
119+
120+
// TODO: Validate simplified query format args here?
121+
// Otherwise some of the validation is skipped when creating the retriever programmatically.
122+
}
123+
124+
public int rankConstant() {
125+
return rankConstant;
93126
}
94127

95128
@Override
96129
public String getName() {
97130
return NAME;
98131
}
99132

133+
@Override
134+
public ActionRequestValidationException validate(
135+
SearchSourceBuilder source,
136+
ActionRequestValidationException validationException,
137+
boolean isScroll,
138+
boolean allowPartialSearchResults
139+
) {
140+
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
141+
return SimplifiedInnerRetrieverUtils.validateSimplifiedFormatParams(
142+
innerRetrievers,
143+
fields,
144+
query,
145+
getName(),
146+
RETRIEVERS_FIELD.getPreferredName(),
147+
FIELDS_FIELD.getPreferredName(),
148+
QUERY_FIELD.getPreferredName(),
149+
validationException
150+
);
151+
}
152+
100153
@Override
101154
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
102-
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
155+
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
103156
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
104157
clone.retrieverName = retrieverName;
105158
return clone;
@@ -162,17 +215,68 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
162215
return topResults;
163216
}
164217

218+
@Override
219+
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
220+
RetrieverBuilder rewritten = this;
221+
222+
ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
223+
if (resolvedIndices != null && query != null) {
224+
// Using the simplified query format
225+
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
226+
if (localIndicesMetadata.size() > 1) {
227+
throw new IllegalArgumentException(
228+
"[" + NAME + "] does not support the simplified query format when querying multiple indices"
229+
);
230+
} else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
231+
throw new IllegalArgumentException(
232+
"[" + NAME + "] does not support the simplified query format when querying remote indices"
233+
);
234+
}
235+
236+
List<RetrieverSource> fieldsInnerRetrievers = SimplifiedInnerRetrieverUtils.generateInnerRetrievers(
237+
fields,
238+
query,
239+
localIndicesMetadata.values(),
240+
r -> {
241+
List<RetrieverSource> retrievers = r.stream()
242+
.map(SimplifiedInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
243+
.toList();
244+
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
245+
},
246+
w -> {
247+
if (w != 1.0f) {
248+
throw new IllegalArgumentException(
249+
"[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
250+
);
251+
}
252+
}
253+
).stream().map(CompoundRetrieverBuilder::convertToRetrieverSource).toList();
254+
255+
if (fieldsInnerRetrievers.isEmpty() == false) {
256+
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
257+
} else {
258+
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
259+
rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
260+
}
261+
}
262+
263+
return rewritten;
264+
}
265+
165266
// ---- FOR TESTING XCONTENT PARSING ----
166267

167268
@Override
168269
public boolean doEquals(Object o) {
169270
RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
170-
return super.doEquals(o) && rankConstant == that.rankConstant;
271+
return super.doEquals(o)
272+
&& Objects.equals(fields, that.fields)
273+
&& Objects.equals(query, that.query)
274+
&& rankConstant == that.rankConstant;
171275
}
172276

173277
@Override
174278
public int doHashCode() {
175-
return Objects.hash(super.doHashCode(), rankConstant);
279+
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
176280
}
177281

178282
@Override
@@ -186,6 +290,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
186290
builder.endArray();
187291
}
188292

293+
if (fields.isEmpty() == false) {
294+
builder.startArray(FIELDS_FIELD.getPreferredName());
295+
for (String field : fields) {
296+
builder.value(field);
297+
}
298+
builder.endArray();
299+
}
300+
if (query != null) {
301+
builder.field(QUERY_FIELD.getPreferredName(), query);
302+
}
303+
189304
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
190305
builder.field(RANK_CONSTANT_FIELD.getPreferredName(), rankConstant);
191306
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.search.SearchRequest;
1111
import org.elasticsearch.common.Strings;
1212
import org.elasticsearch.search.builder.SearchSourceBuilder;
13+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1314
import org.elasticsearch.search.retriever.RetrieverBuilder;
1415
import org.elasticsearch.search.retriever.RetrieverParserContext;
1516
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
@@ -26,6 +27,7 @@
2627
import java.util.ArrayList;
2728
import java.util.List;
2829

30+
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.convertToRetrieverSource;
2931
import static org.hamcrest.Matchers.equalTo;
3032
import static org.hamcrest.Matchers.instanceOf;
3133

@@ -45,13 +47,22 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
4547
if (randomBoolean()) {
4648
rankConstant = randomIntBetween(1, 1000000);
4749
}
48-
var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant);
50+
51+
List<String> fields = null;
52+
String query = null;
53+
if (randomBoolean()) {
54+
fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10));
55+
query = randomAlphaOfLengthBetween(1, 10);
56+
}
57+
4958
int retrieverCount = randomIntBetween(2, 50);
59+
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
5060
while (retrieverCount > 0) {
51-
ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder());
61+
innerRetrievers.add(convertToRetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
5262
--retrieverCount;
5363
}
54-
return ret;
64+
65+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
5566
}
5667

5768
@Override
@@ -94,28 +105,32 @@ protected NamedXContentRegistry xContentRegistry() {
94105
}
95106

96107
public void testRRFRetrieverParsing() throws IOException {
97-
String restContent = "{"
98-
+ " \"retriever\": {"
99-
+ " \"rrf\": {"
100-
+ " \"retrievers\": ["
101-
+ " {"
102-
+ " \"test\": {"
103-
+ " \"value\": \"foo\""
104-
+ " }"
105-
+ " },"
106-
+ " {"
107-
+ " \"test\": {"
108-
+ " \"value\": \"bar\""
109-
+ " }"
110-
+ " }"
111-
+ " ],"
112-
+ " \"rank_window_size\": 100,"
113-
+ " \"rank_constant\": 10,"
114-
+ " \"min_score\": 20.0,"
115-
+ " \"_name\": \"foo_rrf\""
116-
+ " }"
117-
+ " }"
118-
+ "}";
108+
String restContent = """
109+
{
110+
"retriever": {
111+
"rrf": {
112+
"retrievers": [
113+
{
114+
"test": {
115+
"value": "foo"
116+
}
117+
},
118+
{
119+
"test": {
120+
"value": "bar"
121+
}
122+
}
123+
],
124+
"fields": ["field1", "field2"],
125+
"query": "baz",
126+
"rank_window_size": 100,
127+
"rank_constant": 10,
128+
"min_score": 20.0,
129+
"_name": "foo_rrf"
130+
}
131+
}
132+
}
133+
""";
119134
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
120135
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
121136
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);

0 commit comments

Comments
 (0)