Skip to content

Commit 6da6bb2

Browse files
committed
Simplified linear retriever
1 parent 087747b commit 6da6bb2

File tree

10 files changed

+1292
-17
lines changed

10 files changed

+1292
-17
lines changed

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import java.io.IOException;
3838
import java.util.ArrayList;
39+
import java.util.Collections;
3940
import java.util.List;
4041
import java.util.Locale;
4142
import java.util.Objects;
@@ -65,7 +66,7 @@ protected CompoundRetrieverBuilder(List<RetrieverSource> innerRetrievers, int ra
6566

6667
/**
 * Appends a child retriever to {@code innerRetrievers} and returns this builder.
 *
 * <p>This diff hunk shows both versions of the append: the removed line built the
 * {@code RetrieverSource} inline, the added line delegates to
 * {@code convertToRetrieverSource}, which performs the same construction.
 *
 * @param retrieverBuilder the retriever to add as a child
 * @return this builder, cast to {@code T} (the cast is unchecked, hence the suppression)
 */
@SuppressWarnings("unchecked")
6768
public T addChild(RetrieverBuilder retrieverBuilder) {
68-
innerRetrievers.add(new RetrieverSource(retrieverBuilder, null));
69+
innerRetrievers.add(convertToRetrieverSource(retrieverBuilder));
6970
return (T) this;
7071
}
7172

@@ -99,6 +100,11 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
99100
throw new IllegalStateException("PIT is required");
100101
}
101102

103+
RetrieverBuilder rewritten = doRewrite(ctx);
104+
if (rewritten != this) {
105+
return rewritten;
106+
}
107+
102108
// Rewrite prefilters
103109
// We eagerly rewrite prefilters, because some of the innerRetrievers
104110
// could be compound too, so we want to propagate all the necessary filter information to them
@@ -290,6 +296,14 @@ public int rankWindowSize() {
290296
return rankWindowSize;
291297
}
292298

299+
public List<RetrieverSource> innerRetrievers() {
300+
return Collections.unmodifiableList(innerRetrievers);
301+
}
302+
303+
public static RetrieverSource convertToRetrieverSource(RetrieverBuilder retrieverBuilder) {
304+
return new RetrieverSource(retrieverBuilder, null);
305+
}
306+
293307
protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
294308
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
295309
.trackTotalHits(false)
@@ -316,6 +330,16 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu
316330
return sourceBuilder;
317331
}
318332

333+
/**
334+
* Perform any custom rewrite logic necessary
335+
*
336+
* @param ctx The query rewrite context
337+
* @return RetrieverBuilder the rewritten retriever
338+
*/
339+
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
340+
return this;
341+
}
342+
319343
private RankDoc[] getRankDocs(SearchResponse searchResponse) {
320344
int size = searchResponse.getHits().getHits().length;
321345
RankDoc[] docs = new RankDoc[size];

x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,5 +403,5 @@ tasks.named("thirdPartyAudit").configure {
403403
}
404404

405405
// Configures the yamlRestTest task to run against the default distribution.
// This hunk swaps the placeholder argument "to be triaged" for "Uses the inference API";
// the string is presumably the recorded justification — confirm against the build plugin.
tasks.named('yamlRestTest') {
406-
usesDefaultDistribution("to be triaged")
406+
usesDefaultDistribution("Uses the inference API")
407407
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.apache.lucene.index.Term;
11+
import org.apache.lucene.search.DisjunctionMaxQuery;
12+
import org.apache.lucene.search.Query;
13+
import org.apache.lucene.search.TermQuery;
14+
import org.elasticsearch.cluster.ClusterChangedEvent;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.core.IOUtils;
17+
import org.elasticsearch.index.mapper.MapperService;
18+
import org.elasticsearch.index.mapper.MapperServiceTestCase;
19+
import org.elasticsearch.index.mapper.ParsedDocument;
20+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
21+
import org.elasticsearch.index.query.SearchExecutionContext;
22+
import org.elasticsearch.plugins.Plugin;
23+
import org.elasticsearch.test.ClusterServiceUtils;
24+
import org.elasticsearch.test.client.NoOpClient;
25+
import org.elasticsearch.threadpool.TestThreadPool;
26+
import org.elasticsearch.xpack.inference.InferencePlugin;
27+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
28+
import org.junit.AfterClass;
29+
import org.junit.BeforeClass;
30+
31+
import java.util.Collection;
32+
import java.util.List;
33+
import java.util.function.Supplier;
34+
35+
/**
 * Checks how a {@link MultiMatchQueryBuilder} with a wildcard field pattern is turned into a
 * Lucene query when the mapping contains a {@code semantic_text} field next to {@code text}
 * and {@code keyword} fields.
 *
 * <p>The mapper service is built with an {@link InferencePlugin} subclass that serves a
 * statically held {@link ModelRegistry}, created once per test class.
 */
public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase {
36+
// Shared by all tests in this class; created in startModelRegistry and closed in stopModelRegistry.
private static TestThreadPool threadPool;
37+
// Registry handed to the plugin under test; assigned in startModelRegistry.
private static ModelRegistry modelRegistry;
38+
39+
/**
 * {@link InferencePlugin} variant whose model-registry supplier returns the static
 * {@code modelRegistry} field of the enclosing test class.
 */
private static class InferencePluginWithModelRegistry extends InferencePlugin {
40+
InferencePluginWithModelRegistry(Settings settings) {
41+
super(settings);
42+
}
43+
44+
// The lambda reads the static field each time the supplier is invoked, not when it is created.
@Override
45+
protected Supplier<ModelRegistry> getModelRegistry() {
46+
return () -> modelRegistry;
47+
}
48+
}
49+
50+
/**
 * Creates the thread pool, a cluster service on top of it, and the {@link ModelRegistry}
 * backed by a {@link NoOpClient}, then delivers one cluster-changed event to the registry.
 */
@BeforeClass
51+
public static void startModelRegistry() {
52+
threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName());
53+
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
54+
modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
55+
// The event passes the same cluster state twice and reports that the local node is not the master.
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
56+
@Override
57+
public boolean localNodeMaster() {
58+
return false;
59+
}
60+
});
61+
}
62+
63+
/**
 * Closes the shared thread pool via {@code IOUtils.closeWhileHandlingException}.
 * NOTE(review): the cluster service created in startModelRegistry is not closed here — confirm that is intended.
 */
@AfterClass
64+
public static void stopModelRegistry() {
65+
IOUtils.closeWhileHandlingException(threadPool);
66+
}
67+
68+
/**
 * Supplies the plugins used when building the mapper service: a single
 * {@code InferencePluginWithModelRegistry} created with empty settings.
 */
@Override
69+
protected Collection<? extends Plugin> getPlugins() {
70+
return List.of(new InferencePluginWithModelRegistry(Settings.EMPTY));
71+
}
72+
73+
/**
 * Builds a multi-match query for "foo" over the pattern {@code *_field} against a mapping with
 * {@code text_field}, {@code keyword_field} and the semantic_text field {@code inference_field},
 * and asserts that the resulting Lucene query is a {@link DisjunctionMaxQuery} over term queries
 * on {@code text_field} and {@code keyword_field} only.
 */
public void testResolveSemanticTextFieldFromWildcard() throws Exception {
74+
// Mapping with one field of each type; all three names match the "*_field" pattern used below.
MapperService mapperService = createMapperService("""
75+
{
76+
"_doc" : {
77+
"properties": {
78+
"text_field": { "type": "text" },
79+
"keyword_field": { "type": "keyword" },
80+
"inference_field": { "type": "semantic_text", "inference_id": "test_service" }
81+
}
82+
}
83+
}
84+
""");
85+
86+
// Document with "foo" in every field, plus pre-computed sparse-embedding chunks for inference_field.
ParsedDocument doc = mapperService.documentMapper().parse(source("""
87+
{
88+
"text_field" : "foo",
89+
"keyword_field" : "foo",
90+
"inference_field" : "foo",
91+
"_inference_fields": {
92+
"inference_field": {
93+
"inference": {
94+
"inference_id": "test_service",
95+
"model_settings": {
96+
"task_type": "sparse_embedding"
97+
},
98+
"chunks": {
99+
"inference_field": [
100+
{
101+
"start_offset": 0,
102+
"end_offset": 3,
103+
"embeddings": {
104+
"foo": 1.0
105+
}
106+
}
107+
]
108+
}
109+
}
110+
}
111+
}
112+
}
113+
"""));
114+
115+
// Index the parsed document, then build the query against a searcher over that index.
withLuceneIndex(mapperService, iw -> iw.addDocument(doc.rootDoc()), ir -> {
116+
SearchExecutionContext context = createSearchExecutionContext(mapperService, newSearcher(ir));
117+
Query query = new MultiMatchQueryBuilder("foo", "*_field").toQuery(context);
118+
// The expected query has no clause for inference_field and uses a tie-breaker of 0.
Query expected = new DisjunctionMaxQuery(
119+
List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))),
120+
0f
121+
);
122+
assertEquals(expected, query);
123+
});
124+
}
125+
}

x-pack/plugin/rank-rrf/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ dependencies {
3030

3131
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
3232
}
33+
34+
// Configures the yamlRestTest task to run against the default distribution.
// The string argument is presumably the recorded justification — confirm against the build plugin.
tasks.named('yamlRestTest') {
35+
usesDefaultDistribution("Uses the inference API")
36+
}

0 commit comments

Comments
 (0)