Skip to content

Commit 1ce5a9a

Browse files
committed
Added linear retriever rewrite tests
1 parent 6b4c04b commit 1ce5a9a

File tree

2 files changed

+251
-0
lines changed

2 files changed

+251
-0
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverBuilder.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,14 @@ public String getName() {
336336
return NAME;
337337
}
338338

339+
float[] getWeights() {
340+
return weights;
341+
}
342+
343+
ScoreNormalizer[] getNormalizers() {
344+
return normalizers;
345+
}
346+
339347
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
340348
int index = 0;
341349
if (innerRetrievers.isEmpty() == false) {
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.rank.linear;
9+
10+
import org.elasticsearch.action.MockResolvedIndices;
11+
import org.elasticsearch.action.OriginalIndices;
12+
import org.elasticsearch.action.ResolvedIndices;
13+
import org.elasticsearch.action.support.IndicesOptions;
14+
import org.elasticsearch.cluster.metadata.IndexMetadata;
15+
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
16+
import org.elasticsearch.common.bytes.BytesArray;
17+
import org.elasticsearch.common.settings.Settings;
18+
import org.elasticsearch.index.Index;
19+
import org.elasticsearch.index.IndexVersion;
20+
import org.elasticsearch.index.query.MatchQueryBuilder;
21+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
22+
import org.elasticsearch.index.query.QueryRewriteContext;
23+
import org.elasticsearch.search.builder.PointInTimeBuilder;
24+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
25+
import org.elasticsearch.search.retriever.RetrieverBuilder;
26+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
27+
import org.elasticsearch.test.ESTestCase;
28+
29+
import java.util.HashSet;
30+
import java.util.List;
31+
import java.util.Map;
32+
import java.util.Objects;
33+
import java.util.Set;
34+
import java.util.stream.Collectors;
35+
36+
public class LinearRetrieverBuilderTests extends ESTestCase {
37+
public void testSimplifiedParamsRewrite() {
38+
final String indexName = "test-index";
39+
final List<String> testInferenceFields = List.of("semantic_field_1", "semantic_field_2");
40+
final ResolvedIndices resolvedIndices = createMockResolvedIndices(indexName, testInferenceFields);
41+
final QueryRewriteContext queryRewriteContext = new QueryRewriteContext(
42+
parserConfig(),
43+
null,
44+
null,
45+
resolvedIndices,
46+
new PointInTimeBuilder(new BytesArray("pitid")),
47+
null
48+
);
49+
50+
// No wildcards, no per-field boosting
51+
LinearRetrieverBuilder retriever = new LinearRetrieverBuilder(
52+
null,
53+
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
54+
"foo",
55+
MinMaxScoreNormalizer.INSTANCE,
56+
10,
57+
new float[0],
58+
new ScoreNormalizer[0]
59+
);
60+
assertSimplifiedParamsRewrite(
61+
retriever,
62+
queryRewriteContext,
63+
Map.of("field_1", 1.0f, "field_2", 1.0f),
64+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
65+
"foo",
66+
MinMaxScoreNormalizer.INSTANCE
67+
);
68+
69+
// No wildcards, per-field boosting
70+
retriever = new LinearRetrieverBuilder(
71+
null,
72+
List.of("field_1", "field_2^1.5", "semantic_field_1", "semantic_field_2^2"),
73+
"bar",
74+
MinMaxScoreNormalizer.INSTANCE,
75+
10,
76+
new float[0],
77+
new ScoreNormalizer[0]
78+
);
79+
assertSimplifiedParamsRewrite(
80+
retriever,
81+
queryRewriteContext,
82+
Map.of("field_1", 1.0f, "field_2", 1.5f),
83+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 2.0f),
84+
"bar",
85+
MinMaxScoreNormalizer.INSTANCE
86+
);
87+
88+
// Glob matching on inference and non-inference fields with per-field boosting
89+
retriever = new LinearRetrieverBuilder(
90+
null,
91+
List.of("field_*^1.5", "*_field_1^2.5"),
92+
"baz",
93+
MinMaxScoreNormalizer.INSTANCE,
94+
10,
95+
new float[0],
96+
new ScoreNormalizer[0]
97+
);
98+
assertSimplifiedParamsRewrite(
99+
retriever,
100+
queryRewriteContext,
101+
Map.of("field_*", 1.5f, "*_field_1", 2.5f),
102+
Map.of("semantic_field_1", 2.5f),
103+
"baz",
104+
MinMaxScoreNormalizer.INSTANCE
105+
);
106+
107+
// All-fields wildcard
108+
retriever = new LinearRetrieverBuilder(
109+
null,
110+
List.of("*"),
111+
"qux",
112+
MinMaxScoreNormalizer.INSTANCE,
113+
10,
114+
new float[0],
115+
new ScoreNormalizer[0]
116+
);
117+
assertSimplifiedParamsRewrite(
118+
retriever,
119+
queryRewriteContext,
120+
Map.of("*", 1.0f),
121+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
122+
"qux",
123+
MinMaxScoreNormalizer.INSTANCE
124+
);
125+
}
126+
127+
private static ResolvedIndices createMockResolvedIndices(String indexName, List<String> inferenceFields) {
128+
Index index = new Index(indexName, randomAlphaOfLength(10));
129+
IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(index.getName())
130+
.settings(
131+
Settings.builder()
132+
.put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
133+
.put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
134+
)
135+
.numberOfShards(1)
136+
.numberOfReplicas(0);
137+
138+
for (String inferenceField : inferenceFields) {
139+
indexMetadataBuilder.putInferenceField(
140+
new InferenceFieldMetadata(inferenceField, randomAlphaOfLengthBetween(3, 5), new String[] { inferenceField }, null)
141+
);
142+
}
143+
144+
return new MockResolvedIndices(
145+
Map.of(),
146+
new OriginalIndices(new String[] { indexName }, IndicesOptions.DEFAULT),
147+
Map.of(index, indexMetadataBuilder.build())
148+
);
149+
}
150+
151+
private static void assertSimplifiedParamsRewrite(
152+
LinearRetrieverBuilder retriever,
153+
QueryRewriteContext ctx,
154+
Map<String, Float> expectedNonInferenceFields,
155+
Map<String, Float> expectedInferenceFields,
156+
String expectedQuery,
157+
ScoreNormalizer expectedNormalizer
158+
) {
159+
Set<InnerRetriever> expectedInnerRetrievers = Set.of(
160+
new InnerRetriever(
161+
new StandardRetrieverBuilder(
162+
new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
163+
.fields(expectedNonInferenceFields)
164+
),
165+
1.0f,
166+
expectedNormalizer
167+
),
168+
new InnerRetriever(
169+
expectedInferenceFields.entrySet()
170+
.stream()
171+
.map(
172+
e -> new InnerRetriever(
173+
new StandardRetrieverBuilder(new MatchQueryBuilder(e.getKey(), expectedQuery)),
174+
e.getValue(),
175+
expectedNormalizer
176+
)
177+
)
178+
.collect(Collectors.toSet()),
179+
1.0f,
180+
expectedNormalizer
181+
)
182+
);
183+
184+
LinearRetrieverBuilder rewritten = retriever.doRewrite(ctx);
185+
assertNotSame(retriever, rewritten);
186+
assertEquals(expectedInnerRetrievers, getInnerRetrieversAsSet(rewritten));
187+
}
188+
189+
private static Set<InnerRetriever> getInnerRetrieversAsSet(LinearRetrieverBuilder retriever) {
190+
float[] weights = retriever.getWeights();
191+
ScoreNormalizer[] normalizers = retriever.getNormalizers();
192+
193+
int i = 0;
194+
Set<InnerRetriever> innerRetrieversSet = new HashSet<>();
195+
for (CompoundRetrieverBuilder.RetrieverSource innerRetriever : retriever.innerRetrievers()) {
196+
float weight = weights[i];
197+
ScoreNormalizer normalizer = normalizers[i];
198+
199+
if (innerRetriever.retriever() instanceof LinearRetrieverBuilder innerLinearRetriever) {
200+
innerRetrieversSet.add(new InnerRetriever(getInnerRetrieversAsSet(innerLinearRetriever), weight, normalizer));
201+
} else {
202+
innerRetrieversSet.add(new InnerRetriever(innerRetriever.retriever(), weight, normalizer));
203+
}
204+
205+
i++;
206+
}
207+
208+
return innerRetrieversSet;
209+
}
210+
211+
private static class InnerRetriever {
212+
private final Object retriever;
213+
private final float weight;
214+
private final ScoreNormalizer normalizer;
215+
216+
InnerRetriever(RetrieverBuilder retriever, float weight, ScoreNormalizer normalizer) {
217+
this.retriever = retriever;
218+
this.weight = weight;
219+
this.normalizer = normalizer;
220+
}
221+
222+
InnerRetriever(Set<InnerRetriever> innerRetrievers, float weight, ScoreNormalizer normalizer) {
223+
this.retriever = innerRetrievers;
224+
this.weight = weight;
225+
this.normalizer = normalizer;
226+
}
227+
228+
@Override
229+
public boolean equals(Object o) {
230+
if (this == o) return true;
231+
if (o == null || getClass() != o.getClass()) return false;
232+
InnerRetriever that = (InnerRetriever) o;
233+
return Float.compare(weight, that.weight) == 0
234+
&& Objects.equals(retriever, that.retriever)
235+
&& Objects.equals(normalizer, that.normalizer);
236+
}
237+
238+
@Override
239+
public int hashCode() {
240+
return Objects.hash(retriever, weight, normalizer);
241+
}
242+
}
243+
}

0 commit comments

Comments
 (0)