Skip to content

Commit fa2b53b

Browse files
authored
Simplified Linear Retriever (#129200) (#129563)
(cherry picked from commit fc77640) # Conflicts: # x-pack/plugin/inference/build.gradle # x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java # x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java
1 parent 6ae9ada commit fa2b53b

File tree

15 files changed

+1427
-24
lines changed

15 files changed

+1427
-24
lines changed

docs/changelog/129200.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129200
2+
summary: Simplified Linear Retriever
3+
area: Search
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
import java.io.IOException;
3838
import java.util.ArrayList;
39+
import java.util.Collections;
3940
import java.util.List;
4041
import java.util.Locale;
4142
import java.util.Objects;
@@ -53,7 +54,11 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
5354

5455
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
5556

56-
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
57+
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {
58+
public static RetrieverSource from(RetrieverBuilder retriever) {
59+
return new RetrieverSource(retriever, null);
60+
}
61+
}
5762

5863
protected final int rankWindowSize;
5964
protected final List<RetrieverSource> innerRetrievers;
@@ -65,7 +70,7 @@ protected CompoundRetrieverBuilder(List<RetrieverSource> innerRetrievers, int ra
6570

6671
@SuppressWarnings("unchecked")
6772
public T addChild(RetrieverBuilder retrieverBuilder) {
68-
innerRetrievers.add(new RetrieverSource(retrieverBuilder, null));
73+
innerRetrievers.add(RetrieverSource.from(retrieverBuilder));
6974
return (T) this;
7075
}
7176

@@ -99,6 +104,11 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
99104
throw new IllegalStateException("PIT is required");
100105
}
101106

107+
RetrieverBuilder rewritten = doRewrite(ctx);
108+
if (rewritten != this) {
109+
return rewritten;
110+
}
111+
102112
// Rewrite prefilters
103113
// We eagerly rewrite prefilters, because some of the innerRetrievers
104114
// could be compound too, so we want to propagate all the necessary filter information to them
@@ -121,7 +131,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
121131
}
122132
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
123133
if (newRetriever != entry.retriever) {
124-
newRetrievers.add(new RetrieverSource(newRetriever, null));
134+
newRetrievers.add(RetrieverSource.from(newRetriever));
125135
hasChanged |= true;
126136
} else {
127137
var sourceBuilder = entry.source != null
@@ -291,6 +301,10 @@ public int rankWindowSize() {
291301
return rankWindowSize;
292302
}
293303

304+
public List<RetrieverSource> innerRetrievers() {
305+
return Collections.unmodifiableList(innerRetrievers);
306+
}
307+
294308
protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
295309
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
296310
.trackTotalHits(false)
@@ -317,6 +331,16 @@ protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBu
317331
return sourceBuilder;
318332
}
319333

334+
/**
335+
* Perform any custom rewrite logic necessary
336+
*
337+
* @param ctx The query rewrite context
338+
* @return RetrieverBuilder the rewritten retriever
339+
*/
340+
protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
341+
return this;
342+
}
343+
320344
private RankDoc[] getRankDocs(SearchResponse searchResponse) {
321345
int size = searchResponse.getHits().getHits().length;
322346
RankDoc[] docs = new RankDoc[size];

server/src/main/java/org/elasticsearch/search/retriever/RescorerRetrieverBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public static RescorerRetrieverBuilder fromXContent(XContentParser parser, Retri
7878
private final List<RescorerBuilder<?>> rescorers;
7979

8080
public RescorerRetrieverBuilder(RetrieverBuilder retriever, List<RescorerBuilder<?>> rescorers) {
81-
super(List.of(new RetrieverSource(retriever, null)), extractMinWindowSize(rescorers));
81+
super(List.of(RetrieverSource.from(retriever)), extractMinWindowSize(rescorers));
8282
if (rescorers.isEmpty()) {
8383
throw new IllegalArgumentException("Missing rescore definition");
8484
}

x-pack/plugin/inference/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,5 +404,5 @@ tasks.named("thirdPartyAudit").configure {
404404
}
405405

406406
tasks.named('yamlRestTest') {
407-
usesDefaultDistribution()
407+
usesDefaultDistribution("Uses the inference API")
408408
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public TextSimilarityRankRetrieverBuilder(
124124
int rankWindowSize,
125125
boolean failuresAllowed
126126
) {
127-
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
127+
super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize);
128128
this.inferenceId = inferenceId;
129129
this.inferenceText = inferenceText;
130130
this.field = field;
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.apache.lucene.index.Term;
11+
import org.apache.lucene.search.DisjunctionMaxQuery;
12+
import org.apache.lucene.search.Query;
13+
import org.apache.lucene.search.TermQuery;
14+
import org.elasticsearch.cluster.ClusterChangedEvent;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.core.IOUtils;
17+
import org.elasticsearch.index.mapper.MapperService;
18+
import org.elasticsearch.index.mapper.MapperServiceTestCase;
19+
import org.elasticsearch.index.mapper.ParsedDocument;
20+
import org.elasticsearch.index.query.MultiMatchQueryBuilder;
21+
import org.elasticsearch.index.query.SearchExecutionContext;
22+
import org.elasticsearch.plugins.Plugin;
23+
import org.elasticsearch.test.ClusterServiceUtils;
24+
import org.elasticsearch.test.client.NoOpClient;
25+
import org.elasticsearch.threadpool.TestThreadPool;
26+
import org.elasticsearch.xpack.inference.InferencePlugin;
27+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
28+
import org.junit.AfterClass;
29+
import org.junit.BeforeClass;
30+
31+
import java.util.Collection;
32+
import java.util.List;
33+
import java.util.function.Supplier;
34+
35+
public class SemanticMultiMatchQueryBuilderTests extends MapperServiceTestCase {
36+
private static TestThreadPool threadPool;
37+
private static ModelRegistry modelRegistry;
38+
39+
private static class InferencePluginWithModelRegistry extends InferencePlugin {
40+
InferencePluginWithModelRegistry(Settings settings) {
41+
super(settings);
42+
}
43+
44+
@Override
45+
protected Supplier<ModelRegistry> getModelRegistry() {
46+
return () -> modelRegistry;
47+
}
48+
}
49+
50+
@BeforeClass
51+
public static void startModelRegistry() {
52+
threadPool = new TestThreadPool(SemanticMultiMatchQueryBuilderTests.class.getName());
53+
var clusterService = ClusterServiceUtils.createClusterService(threadPool);
54+
modelRegistry = new ModelRegistry(clusterService, new NoOpClient(threadPool));
55+
modelRegistry.clusterChanged(new ClusterChangedEvent("init", clusterService.state(), clusterService.state()) {
56+
@Override
57+
public boolean localNodeMaster() {
58+
return false;
59+
}
60+
});
61+
}
62+
63+
@AfterClass
64+
public static void stopModelRegistry() {
65+
IOUtils.closeWhileHandlingException(threadPool);
66+
}
67+
68+
@Override
69+
protected Collection<? extends Plugin> getPlugins() {
70+
return List.of(new InferencePluginWithModelRegistry(Settings.EMPTY));
71+
}
72+
73+
public void testResolveSemanticTextFieldFromWildcard() throws Exception {
74+
MapperService mapperService = createMapperService("""
75+
{
76+
"_doc" : {
77+
"properties": {
78+
"text_field": { "type": "text" },
79+
"keyword_field": { "type": "keyword" },
80+
"inference_field": { "type": "semantic_text", "inference_id": "test_service" }
81+
}
82+
}
83+
}
84+
""");
85+
86+
ParsedDocument doc = mapperService.documentMapper().parse(source("""
87+
{
88+
"text_field" : "foo",
89+
"keyword_field" : "foo",
90+
"inference_field" : "foo",
91+
"_inference_fields": {
92+
"inference_field": {
93+
"inference": {
94+
"inference_id": "test_service",
95+
"model_settings": {
96+
"task_type": "sparse_embedding"
97+
},
98+
"chunks": {
99+
"inference_field": [
100+
{
101+
"start_offset": 0,
102+
"end_offset": 3,
103+
"embeddings": {
104+
"foo": 1.0
105+
}
106+
}
107+
]
108+
}
109+
}
110+
}
111+
}
112+
}
113+
"""));
114+
115+
withLuceneIndex(mapperService, iw -> iw.addDocument(doc.rootDoc()), ir -> {
116+
SearchExecutionContext context = createSearchExecutionContext(mapperService, newSearcher(ir));
117+
Query query = new MultiMatchQueryBuilder("foo", "*_field").toQuery(context);
118+
Query expected = new DisjunctionMaxQuery(
119+
List.of(new TermQuery(new Term("text_field", "foo")), new TermQuery(new Term("keyword_field", "foo"))),
120+
0f
121+
);
122+
assertEquals(expected, query);
123+
});
124+
}
125+
}

x-pack/plugin/rank-rrf/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,7 @@ dependencies {
3030

3131
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
3232
}
33+
34+
tasks.named('yamlRestTest') {
35+
usesDefaultDistribution("Uses the inference API")
36+
}

0 commit comments

Comments
 (0)