Skip to content

Commit f8923c3

Browse files
committed
Check RRF retriever rank window size and rank constant propagation
1 parent 1cd8797 commit f8923c3

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import java.util.Map;
4040
import java.util.Set;
4141

42+
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
4243
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.convertToRetrieverSource;
4344

4445
/** Tests for the rrf retriever. */
@@ -101,7 +102,7 @@ public void testSimplifiedParamsRewrite() {
101102
null,
102103
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
103104
"foo",
104-
10,
105+
DEFAULT_RANK_WINDOW_SIZE,
105106
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
106107
);
107108
assertSimplifiedParamsRewrite(
@@ -112,12 +113,28 @@ public void testSimplifiedParamsRewrite() {
112113
"foo"
113114
);
114115

116+
// Non-default rank window size and rank constant
117+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
118+
null,
119+
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
120+
"foo2",
121+
DEFAULT_RANK_WINDOW_SIZE * 2,
122+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT / 2
123+
);
124+
assertSimplifiedParamsRewrite(
125+
rrfRetrieverBuilder,
126+
queryRewriteContext,
127+
Map.of("field_1", 1.0f, "field_2", 1.0f),
128+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
129+
"foo2"
130+
);
131+
115132
// Glob matching on inference and non-inference fields
116133
rrfRetrieverBuilder = new RRFRetrieverBuilder(
117134
null,
118135
List.of("field_*", "*_field_1"),
119136
"bar",
120-
10,
137+
DEFAULT_RANK_WINDOW_SIZE,
121138
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
122139
);
123140
assertSimplifiedParamsRewrite(
@@ -129,7 +146,13 @@ public void testSimplifiedParamsRewrite() {
129146
);
130147

131148
// All-fields wildcard
132-
rrfRetrieverBuilder = new RRFRetrieverBuilder(null, List.of("*"), "baz", 10, RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT);
149+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
150+
null,
151+
List.of("*"),
152+
"baz",
153+
DEFAULT_RANK_WINDOW_SIZE,
154+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
155+
);
133156
assertSimplifiedParamsRewrite(
134157
rrfRetrieverBuilder,
135158
queryRewriteContext,
@@ -208,13 +231,17 @@ private static void assertSimplifiedParamsRewrite(
208231

209232
RRFRetrieverBuilder rewritten = retriever.doRewrite(ctx);
210233
assertNotSame(retriever, rewritten);
234+
assertEquals(retriever.rankWindowSize(), rewritten.rankWindowSize());
235+
assertEquals(retriever.rankConstant(), rewritten.rankConstant());
211236
assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewritten));
212237
}
213238

214239
private static Set<Object> getInnerRetrieversAsSet(RRFRetrieverBuilder retriever) {
215240
Set<Object> innerRetrieversSet = new HashSet<>();
216241
for (CompoundRetrieverBuilder.RetrieverSource innerRetriever : retriever.innerRetrievers()) {
217242
if (innerRetriever.retriever() instanceof RRFRetrieverBuilder innerRrfRetriever) {
243+
assertEquals(retriever.rankWindowSize(), innerRrfRetriever.rankWindowSize());
244+
assertEquals(retriever.rankConstant(), innerRrfRetriever.rankConstant());
218245
innerRetrieversSet.add(getInnerRetrieversAsSet(innerRrfRetriever));
219246
} else {
220247
innerRetrieversSet.add(innerRetriever);

0 commit comments

Comments
 (0)