|
34 | 34 | import org.elasticsearch.xcontent.json.JsonXContent; |
35 | 35 |
|
36 | 36 | import java.io.IOException; |
| 37 | +import java.util.HashSet; |
37 | 38 | import java.util.List; |
38 | 39 | import java.util.Map; |
| 40 | +import java.util.Set; |
39 | 41 |
|
40 | 42 | import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.convertToRetrieverSource; |
41 | 43 |
|
@@ -164,28 +166,36 @@ private static void assertSimplifiedParamsRewrite( |
164 | 166 | Map<String, Float> expectedInferenceFields, |
165 | 167 | String expectedQuery |
166 | 168 | ) { |
167 | | - List<CompoundRetrieverBuilder.RetrieverSource> expectedInnerRetrievers = List.of( |
| 169 | + Set<Object> expectedInnerRetrievers = Set.of( |
168 | 170 | convertToRetrieverSource( |
169 | 171 | new StandardRetrieverBuilder( |
170 | 172 | new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS) |
171 | 173 | .fields(expectedNonInferenceFields) |
172 | 174 | ) |
173 | 175 | ), |
174 | | - convertToRetrieverSource(new RRFRetrieverBuilder(expectedInferenceFields.entrySet().stream().map(e -> { |
| 176 | + Set.of(expectedInferenceFields.entrySet().stream().map(e -> { |
175 | 177 | if (e.getValue() != 1.0f) { |
176 | 178 | throw new IllegalArgumentException("Cannot apply per-field weights in RRF"); |
177 | 179 | } |
178 | 180 | return convertToRetrieverSource(new StandardRetrieverBuilder(new MatchQueryBuilder(e.getKey(), expectedQuery))); |
179 | | - }).toList(), retriever.rankWindowSize(), retriever.rankConstant())) |
180 | | - ); |
181 | | - RRFRetrieverBuilder expectedRewritten = new RRFRetrieverBuilder( |
182 | | - expectedInnerRetrievers, |
183 | | - retriever.rankWindowSize(), |
184 | | - retriever.rankConstant() |
| 181 | + }).toArray()) |
185 | 182 | ); |
186 | 183 |
|
187 | 184 | RRFRetrieverBuilder rewritten = retriever.doRewrite(ctx); |
188 | 185 | assertNotSame(retriever, rewritten); |
189 | | - assertEquals(expectedRewritten, rewritten); |
| 186 | + assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewritten)); |
| 187 | + } |
| 188 | + |
| 189 | + private static Set<Object> getInnerRetrieversAsSet(RRFRetrieverBuilder retriever) { |
| 190 | + Set<Object> innerRetrieversSet = new HashSet<>(); |
| 191 | + for (CompoundRetrieverBuilder.RetrieverSource innerRetriever : retriever.innerRetrievers()) { |
| 192 | + if (innerRetriever.retriever() instanceof RRFRetrieverBuilder innerRrfRetriever) { |
| 193 | + innerRetrieversSet.add(getInnerRetrieversAsSet(innerRrfRetriever)); |
| 194 | + } else { |
| 195 | + innerRetrieversSet.add(innerRetriever); |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + return innerRetrieversSet; |
190 | 200 | } |
191 | 201 | } |
0 commit comments