Skip to content

Commit 9a3892f

Browse files
authored
Retrievers: Refactor retriever builder tests (#134799)
1 parent b6fa241 commit 9a3892f

File tree

3 files changed

+282
-305
lines changed

3 files changed

+282
-305
lines changed
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.search.retriever;
9+
10+
import org.elasticsearch.action.MockResolvedIndices;
11+
import org.elasticsearch.action.OriginalIndices;
12+
import org.elasticsearch.action.ResolvedIndices;
13+
import org.elasticsearch.action.support.IndicesOptions;
14+
import org.elasticsearch.cluster.metadata.IndexMetadata;
15+
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
16+
import org.elasticsearch.common.settings.Settings;
17+
import org.elasticsearch.core.Tuple;
18+
import org.elasticsearch.index.Index;
19+
import org.elasticsearch.index.IndexVersion;
20+
import org.elasticsearch.index.query.BoolQueryBuilder;
21+
import org.elasticsearch.index.query.MatchQueryBuilder;
22+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
23+
import org.elasticsearch.index.query.QueryBuilder;
24+
import org.elasticsearch.index.query.QueryRewriteContext;
25+
import org.elasticsearch.index.query.TermsQueryBuilder;
26+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder.RetrieverSource;
27+
import org.elasticsearch.test.ESTestCase;
28+
import org.elasticsearch.xpack.rank.linear.ScoreNormalizer;
29+
30+
import java.util.HashMap;
31+
import java.util.HashSet;
32+
import java.util.List;
33+
import java.util.Map;
34+
import java.util.Objects;
35+
import java.util.Set;
36+
import java.util.stream.Collectors;
37+
38+
public abstract class AbstractRetrieverBuilderTests<T extends CompoundRetrieverBuilder<T>> extends ESTestCase {
39+
40+
protected abstract float[] getWeights(T builder);
41+
42+
protected abstract ScoreNormalizer[] getScoreNormalizers(T builder);
43+
44+
protected abstract void assertCompoundRetriever(T originalRetriever, RetrieverBuilder rewrittenRetriever);
45+
46+
protected static ResolvedIndices createMockResolvedIndices(
47+
Map<String, List<String>> localIndexInferenceFields,
48+
Map<String, String> remoteIndexNames,
49+
Map<String, String> commonInferenceIds
50+
) {
51+
Map<Index, IndexMetadata> indexMetadata = new HashMap<>();
52+
53+
for (var indexEntry : localIndexInferenceFields.entrySet()) {
54+
String indexName = indexEntry.getKey();
55+
List<String> inferenceFields = indexEntry.getValue();
56+
57+
Index index = new Index(indexName, randomAlphaOfLength(10));
58+
59+
IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(index.getName())
60+
.settings(
61+
Settings.builder()
62+
.put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
63+
.put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
64+
)
65+
.numberOfShards(1)
66+
.numberOfReplicas(0);
67+
68+
for (String inferenceField : inferenceFields) {
69+
String inferenceId = commonInferenceIds.containsKey(inferenceField)
70+
? commonInferenceIds.get(inferenceField)
71+
: randomAlphaOfLengthBetween(3, 5);
72+
73+
indexMetadataBuilder.putInferenceField(
74+
new InferenceFieldMetadata(inferenceField, inferenceId, new String[] { inferenceField }, null)
75+
);
76+
}
77+
78+
indexMetadata.put(index, indexMetadataBuilder.build());
79+
}
80+
81+
Map<String, OriginalIndices> remoteIndices = new HashMap<>();
82+
if (remoteIndexNames != null) {
83+
for (Map.Entry<String, String> entry : remoteIndexNames.entrySet()) {
84+
remoteIndices.put(entry.getKey(), new OriginalIndices(new String[] { entry.getValue() }, IndicesOptions.DEFAULT));
85+
}
86+
}
87+
88+
return new MockResolvedIndices(
89+
remoteIndices,
90+
new OriginalIndices(localIndexInferenceFields.keySet().toArray(new String[0]), IndicesOptions.DEFAULT),
91+
indexMetadata
92+
);
93+
}
94+
95+
protected void assertMultiFieldsParamsRewrite(
96+
T retriever,
97+
QueryRewriteContext ctx,
98+
Map<String, Float> expectedNonInferenceFields,
99+
Map<String, Float> expectedInferenceFields,
100+
String expectedQuery,
101+
ScoreNormalizer expectedNormalizer
102+
) {
103+
Map<Tuple<String, List<String>>, Float> inferenceFields = new HashMap<>();
104+
expectedInferenceFields.forEach((key, value) -> inferenceFields.put(new Tuple<>(key, List.of()), value));
105+
106+
assertMultiIndexMultiFieldsParamsRewrite(
107+
retriever,
108+
ctx,
109+
Map.of(expectedNonInferenceFields, List.of()),
110+
inferenceFields,
111+
expectedQuery,
112+
expectedNormalizer
113+
);
114+
}
115+
116+
@SuppressWarnings("unchecked")
117+
protected void assertMultiIndexMultiFieldsParamsRewrite(
118+
T retriever,
119+
QueryRewriteContext ctx,
120+
Map<Map<String, Float>, List<String>> expectedNonInferenceFields,
121+
Map<Tuple<String, List<String>>, Float> expectedInferenceFields,
122+
String expectedQuery,
123+
ScoreNormalizer expectedNormalizer
124+
) {
125+
Set<QueryBuilder> expectedLexicalQueryBuilders = expectedNonInferenceFields.entrySet().stream().map(entry -> {
126+
Map<String, Float> fields = entry.getKey();
127+
List<String> indices = entry.getValue();
128+
129+
QueryBuilder queryBuilder = new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
130+
.fields(fields);
131+
132+
if (indices.isEmpty() == false) {
133+
queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder("_index", indices));
134+
}
135+
return queryBuilder;
136+
}).collect(Collectors.toSet());
137+
138+
Set<InnerRetriever> expectedInnerSemanticRetrievers = expectedInferenceFields.entrySet().stream().map(entry -> {
139+
var groupedInferenceField = entry.getKey();
140+
var fieldName = groupedInferenceField.v1();
141+
var indices = groupedInferenceField.v2();
142+
var weight = entry.getValue();
143+
QueryBuilder queryBuilder = new MatchQueryBuilder(fieldName, expectedQuery);
144+
if (indices.isEmpty() == false) {
145+
queryBuilder = new BoolQueryBuilder().must(queryBuilder).filter(new TermsQueryBuilder("_index", indices));
146+
}
147+
return new InnerRetriever(new StandardRetrieverBuilder(queryBuilder), weight, expectedNormalizer);
148+
}).collect(Collectors.toSet());
149+
150+
RetrieverBuilder rewritten = retriever.doRewrite(ctx);
151+
assertNotSame(retriever, rewritten);
152+
assertCompoundRetriever(retriever, rewritten);
153+
154+
boolean assertedLexical = false;
155+
boolean assertedSemantic = false;
156+
157+
for (InnerRetriever topInnerRetriever : getInnerRetrieversAsSet(retriever, (T) rewritten)) {
158+
assertEquals(expectedNormalizer, topInnerRetriever.normalizer);
159+
assertEquals(1.0f, topInnerRetriever.weight, 0.0f);
160+
161+
if (topInnerRetriever.retriever instanceof StandardRetrieverBuilder standardRetrieverBuilder) {
162+
assertFalse("the lexical retriever is only asserted once", assertedLexical);
163+
assertFalse(expectedNonInferenceFields.isEmpty());
164+
165+
QueryBuilder topDocsQueryBuilder = standardRetrieverBuilder.topDocsQuery();
166+
if (expectedLexicalQueryBuilders.size() == 1) {
167+
assertEquals(topDocsQueryBuilder, expectedLexicalQueryBuilders.iterator().next());
168+
} else {
169+
assertTrue(topDocsQueryBuilder instanceof BoolQueryBuilder);
170+
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) topDocsQueryBuilder;
171+
assertEquals(new HashSet<>(expectedLexicalQueryBuilders), new HashSet<>(boolQueryBuilder.should()));
172+
}
173+
assertedLexical = true;
174+
} else {
175+
assertFalse("the semantic retriever is only asserted once", assertedSemantic);
176+
assertFalse(expectedInferenceFields.isEmpty());
177+
assertEquals(expectedInnerSemanticRetrievers, topInnerRetriever.retriever);
178+
assertedSemantic = true;
179+
}
180+
}
181+
}
182+
183+
@SuppressWarnings("unchecked")
184+
private Set<InnerRetriever> getInnerRetrieversAsSet(T originalRetriever, T rewrittenRetriever) {
185+
float[] weights = getWeights(rewrittenRetriever);
186+
ScoreNormalizer[] normalizers = getScoreNormalizers(rewrittenRetriever);
187+
188+
int i = 0;
189+
Set<InnerRetriever> innerRetrieversSet = new HashSet<>();
190+
for (RetrieverSource innerRetriever : rewrittenRetriever.innerRetrievers()) {
191+
float weight = weights[i];
192+
ScoreNormalizer normalizer = normalizers != null ? normalizers[i] : null;
193+
194+
if (innerRetriever.retriever() instanceof CompoundRetrieverBuilder<?> compoundRetriever) {
195+
assertCompoundRetriever(originalRetriever, compoundRetriever);
196+
innerRetrieversSet.add(
197+
new InnerRetriever(getInnerRetrieversAsSet(originalRetriever, (T) compoundRetriever), weight, normalizer)
198+
);
199+
} else {
200+
innerRetrieversSet.add(new InnerRetriever(innerRetriever.retriever(), weight, normalizer));
201+
}
202+
203+
i++;
204+
}
205+
206+
return innerRetrieversSet;
207+
}
208+
209+
private static class InnerRetriever {
210+
private final Object retriever;
211+
private final float weight;
212+
private final ScoreNormalizer normalizer;
213+
214+
InnerRetriever(RetrieverBuilder retriever, float weight, ScoreNormalizer normalizer) {
215+
this.retriever = retriever;
216+
this.weight = weight;
217+
this.normalizer = normalizer;
218+
}
219+
220+
InnerRetriever(Set<InnerRetriever> innerRetrievers, float weight, ScoreNormalizer normalizer) {
221+
this.retriever = innerRetrievers;
222+
this.weight = weight;
223+
this.normalizer = normalizer;
224+
}
225+
226+
@Override
227+
public boolean equals(Object o) {
228+
if (this == o) return true;
229+
if (o == null || getClass() != o.getClass()) return false;
230+
InnerRetriever that = (InnerRetriever) o;
231+
return Float.compare(weight, that.weight) == 0
232+
&& Objects.equals(retriever, that.retriever)
233+
&& Objects.equals(normalizer, that.normalizer);
234+
}
235+
236+
@Override
237+
public int hashCode() {
238+
return Objects.hash(retriever, weight, normalizer);
239+
}
240+
}
241+
}

0 commit comments

Comments
 (0)