Skip to content

Commit 1cd8797

Browse files
committed
Check linear retriever rank window size propagation
1 parent 1240aa1 commit 1cd8797

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilderTests.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.Set;
3434
import java.util.stream.Collectors;
3535

36+
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
37+
3638
public class LinearRetrieverBuilderTests extends ESTestCase {
3739
public void testSimplifiedParamsRewrite() {
3840
final String indexName = "test-index";
@@ -53,7 +55,7 @@ public void testSimplifiedParamsRewrite() {
5355
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
5456
"foo",
5557
MinMaxScoreNormalizer.INSTANCE,
56-
10,
58+
DEFAULT_RANK_WINDOW_SIZE,
5759
new float[0],
5860
new ScoreNormalizer[0]
5961
);
@@ -66,13 +68,32 @@ public void testSimplifiedParamsRewrite() {
6668
MinMaxScoreNormalizer.INSTANCE
6769
);
6870

71+
// Non-default rank window size
72+
retriever = new LinearRetrieverBuilder(
73+
null,
74+
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
75+
"foo2",
76+
MinMaxScoreNormalizer.INSTANCE,
77+
DEFAULT_RANK_WINDOW_SIZE * 2,
78+
new float[0],
79+
new ScoreNormalizer[0]
80+
);
81+
assertSimplifiedParamsRewrite(
82+
retriever,
83+
queryRewriteContext,
84+
Map.of("field_1", 1.0f, "field_2", 1.0f),
85+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
86+
"foo2",
87+
MinMaxScoreNormalizer.INSTANCE
88+
);
89+
6990
// No wildcards, per-field boosting
7091
retriever = new LinearRetrieverBuilder(
7192
null,
7293
List.of("field_1", "field_2^1.5", "semantic_field_1", "semantic_field_2^2"),
7394
"bar",
7495
MinMaxScoreNormalizer.INSTANCE,
75-
10,
96+
DEFAULT_RANK_WINDOW_SIZE,
7697
new float[0],
7798
new ScoreNormalizer[0]
7899
);
@@ -91,7 +112,7 @@ public void testSimplifiedParamsRewrite() {
91112
List.of("field_*^1.5", "*_field_1^2.5"),
92113
"baz",
93114
MinMaxScoreNormalizer.INSTANCE,
94-
10,
115+
DEFAULT_RANK_WINDOW_SIZE,
95116
new float[0],
96117
new ScoreNormalizer[0]
97118
);
@@ -110,7 +131,7 @@ public void testSimplifiedParamsRewrite() {
110131
List.of("*"),
111132
"qux",
112133
MinMaxScoreNormalizer.INSTANCE,
113-
10,
134+
DEFAULT_RANK_WINDOW_SIZE,
114135
new float[0],
115136
new ScoreNormalizer[0]
116137
);
@@ -183,6 +204,7 @@ private static void assertSimplifiedParamsRewrite(
183204

184205
LinearRetrieverBuilder rewritten = retriever.doRewrite(ctx);
185206
assertNotSame(retriever, rewritten);
207+
assertEquals(retriever.rankWindowSize(), rewritten.rankWindowSize());
186208
assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewritten));
187209
}
188210

@@ -197,6 +219,7 @@ private static Set<InnerRetriever> getInnerRetrieversAsSet(LinearRetrieverBuilde
197219
ScoreNormalizer normalizer = normalizers[i];
198220

199221
if (innerRetriever.retriever() instanceof LinearRetrieverBuilder innerLinearRetriever) {
222+
assertEquals(retriever.rankWindowSize(), innerLinearRetriever.rankWindowSize());
200223
innerRetrieversSet.add(new InnerRetriever(getInnerRetrieversAsSet(innerLinearRetriever), weight, normalizer));
201224
} else {
202225
innerRetrieversSet.add(new InnerRetriever(innerRetriever.retriever(), weight, normalizer));

0 commit comments

Comments
 (0)