Skip to content

Commit 397a1b8

Browse files
committed
Updated RRFRetrieverBuilderParsingTests
1 parent 22594a5 commit 397a1b8

File tree

2 files changed

+44
-31
lines changed

2 files changed

+44
-31
lines changed

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ public int rankWindowSize() {
295295
return rankWindowSize;
296296
}
297297

298+
public static RetrieverSource convertToRetrieverSource(RetrieverBuilder retrieverBuilder) {
299+
return new RetrieverSource(retrieverBuilder, null);
300+
}
301+
298302
protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
299303
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
300304
.trackTotalHits(false)
@@ -332,10 +336,6 @@ protected T doRewrite(QueryRewriteContext ctx) {
332336
return (T) this;
333337
}
334338

335-
protected static RetrieverSource convertToRetrieverSource(RetrieverBuilder retrieverBuilder) {
336-
return new RetrieverSource(retrieverBuilder, null);
337-
}
338-
339339
private RankDoc[] getRankDocs(SearchResponse searchResponse) {
340340
int size = searchResponse.getHits().getHits().length;
341341
RankDoc[] docs = new RankDoc[size];

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.action.search.SearchRequest;
1111
import org.elasticsearch.common.Strings;
1212
import org.elasticsearch.search.builder.SearchSourceBuilder;
13+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1314
import org.elasticsearch.search.retriever.RetrieverBuilder;
1415
import org.elasticsearch.search.retriever.RetrieverParserContext;
1516
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
@@ -26,11 +27,10 @@
2627
import java.util.ArrayList;
2728
import java.util.List;
2829

30+
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.convertToRetrieverSource;
2931
import static org.hamcrest.Matchers.equalTo;
3032
import static org.hamcrest.Matchers.instanceOf;
3133

32-
// TODO: Add simplified format tests?
33-
3434
public class RRFRetrieverBuilderParsingTests extends AbstractXContentTestCase<RRFRetrieverBuilder> {
3535

3636
/**
@@ -47,13 +47,22 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
4747
if (randomBoolean()) {
4848
rankConstant = randomIntBetween(1, 1000000);
4949
}
50-
var ret = new RRFRetrieverBuilder(rankWindowSize, rankConstant);
50+
51+
List<String> fields = null;
52+
String query = null;
53+
if (randomBoolean()) {
54+
fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10));
55+
query = randomAlphaOfLengthBetween(1, 10);
56+
}
57+
5158
int retrieverCount = randomIntBetween(2, 50);
59+
List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = new ArrayList<>(retrieverCount);
5260
while (retrieverCount > 0) {
53-
ret.addChild(TestRetrieverBuilder.createRandomTestRetrieverBuilder());
61+
innerRetrievers.add(convertToRetrieverSource(TestRetrieverBuilder.createRandomTestRetrieverBuilder()));
5462
--retrieverCount;
5563
}
56-
return ret;
64+
65+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
5766
}
5867

5968
@Override
@@ -96,28 +105,32 @@ protected NamedXContentRegistry xContentRegistry() {
96105
}
97106

98107
public void testRRFRetrieverParsing() throws IOException {
99-
String restContent = "{"
100-
+ " \"retriever\": {"
101-
+ " \"rrf\": {"
102-
+ " \"retrievers\": ["
103-
+ " {"
104-
+ " \"test\": {"
105-
+ " \"value\": \"foo\""
106-
+ " }"
107-
+ " },"
108-
+ " {"
109-
+ " \"test\": {"
110-
+ " \"value\": \"bar\""
111-
+ " }"
112-
+ " }"
113-
+ " ],"
114-
+ " \"rank_window_size\": 100,"
115-
+ " \"rank_constant\": 10,"
116-
+ " \"min_score\": 20.0,"
117-
+ " \"_name\": \"foo_rrf\""
118-
+ " }"
119-
+ " }"
120-
+ "}";
108+
String restContent = """
109+
{
110+
"retriever": {
111+
"rrf": {
112+
"retrievers": [
113+
{
114+
"test": {
115+
"value": "foo"
116+
}
117+
},
118+
{
119+
"test": {
120+
"value": "bar"
121+
}
122+
}
123+
],
124+
"fields": ["field1", "field2"],
125+
"query": "baz",
126+
"rank_window_size": 100,
127+
"rank_constant": 10,
128+
"min_score": 20.0,
129+
"_name": "foo_rrf"
130+
}
131+
}
132+
}
133+
""";
121134
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
122135
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {
123136
SearchSourceBuilder source = new SearchSourceBuilder().parseXContent(jsonParser, true, searchUsageHolder, nf -> true);

0 commit comments

Comments
 (0)