Skip to content

Commit d43ac88

Browse files
committed
Fixed rewrite tests
1 parent e7b5476 commit d43ac88

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import java.io.IOException;
3838
import java.util.ArrayList;
39+
import java.util.Collections;
3940
import java.util.List;
4041
import java.util.Locale;
4142
import java.util.Objects;
@@ -295,6 +296,10 @@ public int rankWindowSize() {
295296
return rankWindowSize;
296297
}
297298

299+
public List<RetrieverSource> innerRetrievers() {
300+
return Collections.unmodifiableList(innerRetrievers);
301+
}
302+
298303
public static RetrieverSource convertToRetrieverSource(RetrieverBuilder retrieverBuilder) {
299304
return new RetrieverSource(retrieverBuilder, null);
300305
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
import org.elasticsearch.xcontent.json.JsonXContent;
3535

3636
import java.io.IOException;
37+
import java.util.HashSet;
3738
import java.util.List;
3839
import java.util.Map;
40+
import java.util.Set;
3941

4042
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.convertToRetrieverSource;
4143

@@ -164,28 +166,36 @@ private static void assertSimplifiedParamsRewrite(
164166
Map<String, Float> expectedInferenceFields,
165167
String expectedQuery
166168
) {
167-
List<CompoundRetrieverBuilder.RetrieverSource> expectedInnerRetrievers = List.of(
169+
Set<Object> expectedInnerRetrievers = Set.of(
168170
convertToRetrieverSource(
169171
new StandardRetrieverBuilder(
170172
new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
171173
.fields(expectedNonInferenceFields)
172174
)
173175
),
174-
convertToRetrieverSource(new RRFRetrieverBuilder(expectedInferenceFields.entrySet().stream().map(e -> {
176+
Set.of(expectedInferenceFields.entrySet().stream().map(e -> {
175177
if (e.getValue() != 1.0f) {
176178
throw new IllegalArgumentException("Cannot apply per-field weights in RRF");
177179
}
178180
return convertToRetrieverSource(new StandardRetrieverBuilder(new MatchQueryBuilder(e.getKey(), expectedQuery)));
179-
}).toList(), retriever.rankWindowSize(), retriever.rankConstant()))
180-
);
181-
RRFRetrieverBuilder expectedRewritten = new RRFRetrieverBuilder(
182-
expectedInnerRetrievers,
183-
retriever.rankWindowSize(),
184-
retriever.rankConstant()
181+
}).toArray())
185182
);
186183

187184
RRFRetrieverBuilder rewritten = retriever.doRewrite(ctx);
188185
assertNotSame(retriever, rewritten);
189-
assertEquals(expectedRewritten, rewritten);
186+
assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewritten));
187+
}
188+
189+
private static Set<Object> getInnerRetrieversAsSet(RRFRetrieverBuilder retriever) {
190+
Set<Object> innerRetrieversSet = new HashSet<>();
191+
for (CompoundRetrieverBuilder.RetrieverSource innerRetriever : retriever.innerRetrievers()) {
192+
if (innerRetriever.retriever() instanceof RRFRetrieverBuilder innerRrfRetriever) {
193+
innerRetrieversSet.add(getInnerRetrieversAsSet(innerRrfRetriever));
194+
} else {
195+
innerRetrieversSet.add(innerRetriever);
196+
}
197+
}
198+
199+
return innerRetrieversSet;
190200
}
191201
}

0 commit comments

Comments
 (0)