Skip to content

Commit 1703f1f

Browse files
committed
_almost_ working version of rewrite tests
1 parent 397a1b8 commit 1703f1f

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@
5454
import static org.hamcrest.Matchers.instanceOf;
5555
import static org.hamcrest.Matchers.lessThanOrEqualTo;
5656

57-
// TODO: Add simplified format tests
58-
5957
@ESIntegTestCase.ClusterScope(minNumDataNodes = 3)
6058
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
6159

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ public RRFRetrieverBuilder(int rankWindowSize, int rankConstant) {
117117
// Otherwise some of the validation is skipped when creating the retriever programmatically.
118118
}
119119

120+
public int rankConstant() {
121+
return rankConstant;
122+
}
123+
120124
@Override
121125
public String getName() {
122126
return NAME;

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,26 @@
77

88
package org.elasticsearch.xpack.rank.rrf;
99

10+
import org.elasticsearch.action.MockResolvedIndices;
11+
import org.elasticsearch.action.OriginalIndices;
12+
import org.elasticsearch.action.ResolvedIndices;
13+
import org.elasticsearch.action.support.IndicesOptions;
14+
import org.elasticsearch.cluster.metadata.IndexMetadata;
15+
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
1016
import org.elasticsearch.common.bytes.BytesArray;
1117
import org.elasticsearch.common.settings.Settings;
18+
import org.elasticsearch.index.Index;
19+
import org.elasticsearch.index.IndexVersion;
20+
import org.elasticsearch.index.query.MatchQueryBuilder;
21+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
1222
import org.elasticsearch.index.query.QueryRewriteContext;
1323
import org.elasticsearch.search.SearchModule;
1424
import org.elasticsearch.search.builder.PointInTimeBuilder;
1525
import org.elasticsearch.search.builder.SearchSourceBuilder;
26+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1627
import org.elasticsearch.search.retriever.RetrieverBuilder;
1728
import org.elasticsearch.search.retriever.RetrieverParserContext;
29+
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
1830
import org.elasticsearch.test.ESTestCase;
1931
import org.elasticsearch.xcontent.NamedXContentRegistry;
2032
import org.elasticsearch.xcontent.ParseField;
@@ -23,6 +35,9 @@
2335

2436
import java.io.IOException;
2537
import java.util.List;
38+
import java.util.Map;
39+
40+
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.convertToRetrieverSource;
2641

2742
/** Tests for the rrf retriever. */
2843
public class RRFRetrieverBuilderTests extends ESTestCase {
@@ -66,6 +81,37 @@ public void testRetrieverExtractionErrors() throws IOException {
6681
}
6782
}
6883

84+
public void testSimplifiedParamsRewrite() {
85+
final String indexName = "test-index";
86+
final List<String> testInferenceFields = List.of("semantic_field_1", "semantic_field_2");
87+
final ResolvedIndices resolvedIndices = createMockResolvedIndices(indexName, testInferenceFields);
88+
final QueryRewriteContext queryRewriteContext = new QueryRewriteContext(
89+
parserConfig(),
90+
null,
91+
null,
92+
resolvedIndices,
93+
new PointInTimeBuilder(new BytesArray("pitid")),
94+
null
95+
);
96+
97+
RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder(
98+
null,
99+
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
100+
"foo",
101+
10,
102+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT
103+
);
104+
assertSimplifiedParamsRewrite(
105+
rrfRetrieverBuilder,
106+
queryRewriteContext,
107+
Map.of("field_1", 1.0f, "field_2", 1.0f),
108+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
109+
"foo"
110+
);
111+
112+
// TODO: Test with wildcard resolution
113+
}
114+
69115
@Override
70116
protected NamedXContentRegistry xContentRegistry() {
71117
List<NamedXContentRegistry.Entry> entries = new SearchModule(Settings.EMPTY, List.of()).getNamedXContents();
@@ -86,4 +132,69 @@ protected NamedXContentRegistry xContentRegistry() {
86132
);
87133
return new NamedXContentRegistry(entries);
88134
}
135+
136+
private static ResolvedIndices createMockResolvedIndices(String indexName, List<String> inferenceFields) {
137+
Index index = new Index(indexName, randomAlphaOfLength(10));
138+
IndexMetadata.Builder indexMetadataBuilder = IndexMetadata.builder(index.getName())
139+
.settings(
140+
Settings.builder()
141+
.put(IndexMetadata.SETTING_VERSION_CREATED, IndexVersion.current())
142+
.put(IndexMetadata.SETTING_INDEX_UUID, index.getUUID())
143+
)
144+
.numberOfShards(1)
145+
.numberOfReplicas(0);
146+
147+
for (String inferenceField : inferenceFields) {
148+
indexMetadataBuilder.putInferenceField(
149+
new InferenceFieldMetadata(inferenceField, randomAlphaOfLengthBetween(3, 5), new String[] { inferenceField }, null)
150+
);
151+
}
152+
153+
return new MockResolvedIndices(
154+
Map.of(),
155+
new OriginalIndices(new String[] { indexName }, IndicesOptions.DEFAULT),
156+
Map.of(index, indexMetadataBuilder.build())
157+
);
158+
}
159+
160+
private static void assertSimplifiedParamsRewrite(
161+
RRFRetrieverBuilder retriever,
162+
QueryRewriteContext ctx,
163+
Map<String, Float> expectedNonInferenceFields,
164+
Map<String, Float> expectedInferenceFields,
165+
String expectedQuery
166+
) {
167+
List<CompoundRetrieverBuilder.RetrieverSource> expectedInnerRetrievers = List.of(
168+
convertToRetrieverSource(
169+
new StandardRetrieverBuilder(
170+
new MultiMatchQueryBuilder(expectedQuery).type(MultiMatchQueryBuilder.Type.MOST_FIELDS)
171+
.fields(expectedNonInferenceFields)
172+
)
173+
),
174+
convertToRetrieverSource(
175+
new RRFRetrieverBuilder(
176+
expectedInferenceFields.entrySet()
177+
.stream()
178+
.map(e -> {
179+
if (e.getValue() != 1.0f) {
180+
throw new IllegalArgumentException("Cannot apply per-field weights in RRF");
181+
}
182+
return convertToRetrieverSource(new StandardRetrieverBuilder(new MatchQueryBuilder(e.getKey(), expectedQuery)));
183+
})
184+
.toList(),
185+
retriever.rankWindowSize(),
186+
retriever.rankConstant()
187+
)
188+
)
189+
);
190+
RRFRetrieverBuilder expectedRewritten = new RRFRetrieverBuilder(
191+
expectedInnerRetrievers,
192+
retriever.rankWindowSize(),
193+
retriever.rankConstant()
194+
);
195+
196+
RRFRetrieverBuilder rewritten = retriever.doRewrite(ctx);
197+
assertNotSame(retriever, rewritten);
198+
assertEquals(expectedRewritten, rewritten);
199+
}
89200
}

0 commit comments

Comments
 (0)