Skip to content

Commit 3e57a57

Browse files
authored
[8.x] Fix for propagating filters from compound to inner retrievers (#117914) (#118046)
* Fix for propagating filters from compound to inner retrievers (#117914) * Update RRFRetrieverBuilderIT.java
1 parent 36d8307 commit 3e57a57

File tree

13 files changed

+180
-48
lines changed

13 files changed

+180
-48
lines changed

docs/changelog/117914.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117914
2+
summary: Fix for propagating filters from compound to inner retrievers
3+
area: Ranking
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.action.search.SearchRequest;
2121
import org.elasticsearch.action.search.SearchResponse;
2222
import org.elasticsearch.action.search.TransportMultiSearchAction;
23+
import org.elasticsearch.features.NodeFeature;
2324
import org.elasticsearch.index.query.QueryBuilder;
2425
import org.elasticsearch.index.query.QueryRewriteContext;
2526
import org.elasticsearch.rest.RestStatus;
@@ -46,6 +47,8 @@
4647
*/
4748
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {
4849

50+
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
51+
4952
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
5053

5154
protected final int rankWindowSize;
@@ -64,9 +67,9 @@ public T addChild(RetrieverBuilder retrieverBuilder) {
6467

6568
/**
6669
* Returns a clone of the original retriever, replacing the sub-retrievers with
67-
* the provided {@code newChildRetrievers}.
70+
* the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
6871
*/
69-
protected abstract T clone(List<RetrieverSource> newChildRetrievers);
72+
protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);
7073

7174
/**
7275
* Combines the provided {@code rankResults} to return the final top documents.
@@ -85,13 +88,25 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
8588
}
8689

8790
// Rewrite prefilters
88-
boolean hasChanged = false;
91+
// We eagerly rewrite prefilters, because some of the innerRetrievers
92+
// could be compound too, so we want to propagate all the necessary filter information to them
93+
// and have it available as part of their own rewrite step
8994
var newPreFilters = rewritePreFilters(ctx);
90-
hasChanged |= newPreFilters != preFilterQueryBuilders;
95+
if (newPreFilters != preFilterQueryBuilders) {
96+
return clone(innerRetrievers, newPreFilters);
97+
}
9198

99+
boolean hasChanged = false;
92100
// Rewrite retriever sources
93101
List<RetrieverSource> newRetrievers = new ArrayList<>();
94102
for (var entry : innerRetrievers) {
103+
// we propagate the filters only for compound retrievers as they won't be attached through
104+
// the createSearchSourceBuilder.
105+
// We could remove this check, but we would end up adding the same filters
106+
multiple times when we re-enter this rewrite after an inner retriever has rewritten itself
107+
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
108+
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
109+
}
95110
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
96111
if (newRetriever != entry.retriever) {
97112
newRetrievers.add(new RetrieverSource(newRetriever, null));
@@ -106,7 +121,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
106121
}
107122
}
108123
if (hasChanged) {
109-
return clone(newRetrievers);
124+
return clone(newRetrievers, newPreFilters);
110125
}
111126

112127
// execute searches
@@ -166,12 +181,7 @@ public void onFailure(Exception e) {
166181
});
167182
});
168183

169-
return new RankDocsRetrieverBuilder(
170-
rankWindowSize,
171-
newRetrievers.stream().map(s -> s.retriever).toList(),
172-
results::get,
173-
newPreFilters
174-
);
184+
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
175185
}
176186

177187
@Override

server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
184184
ll.onResponse(null);
185185
}));
186186
});
187-
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
188-
return rewritten;
187+
return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
189188
}
190189
return super.rewrite(ctx);
191190
}

server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
3333
final List<RetrieverBuilder> sources;
3434
final Supplier<RankDoc[]> rankDocs;
3535

36-
public RankDocsRetrieverBuilder(
37-
int rankWindowSize,
38-
List<RetrieverBuilder> sources,
39-
Supplier<RankDoc[]> rankDocs,
40-
List<QueryBuilder> preFilterQueryBuilders
41-
) {
36+
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
4237
this.rankWindowSize = rankWindowSize;
4338
this.rankDocs = rankDocs;
4439
if (sources == null || sources.isEmpty()) {
4540
throw new IllegalArgumentException("sources must not be null or empty");
4641
}
4742
this.sources = sources;
48-
this.preFilterQueryBuilders = preFilterQueryBuilders;
4943
}
5044

5145
@Override
@@ -73,10 +67,6 @@ private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException
7367
@Override
7468
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
7569
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
76-
var rewrittenFilters = rewritePreFilters(ctx);
77-
if (rewrittenFilters != preFilterQueryBuilders) {
78-
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
79-
}
8070
return this;
8171
}
8272

@@ -94,7 +84,7 @@ public QueryBuilder topDocsQuery() {
9484
boolQuery.should(query);
9585
}
9686
}
97-
// ignore prefilters of this level, they are already propagated to children
87+
// ignore prefilters of this level, they were already propagated to children
9888
return boolQuery;
9989
}
10090

@@ -133,7 +123,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
133123
} else {
134124
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
135125
}
136-
// ignore prefilters of this level, they are already propagated to children
126+
// ignore prefilters of this level, they were already propagated to children
137127
searchSourceBuilder.query(rankQuery);
138128
if (sourceHasMinScore()) {
139129
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());

server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,7 @@ private List<QueryBuilder> preFilters(QueryRewriteContext queryRewriteContext) t
9595
}
9696

9797
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
98-
return new RankDocsRetrieverBuilder(
99-
randomIntBetween(1, 100),
100-
innerRetrievers(queryRewriteContext),
101-
rankDocsSupplier(),
102-
preFilters(queryRewriteContext)
103-
);
98+
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
10499
}
105100

106101
public void testExtractToSearchSourceBuilder() throws IOException {

server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
/**
2828
* A SearchPlugin to exercise query vector builder
2929
*/
30-
class TestQueryVectorBuilderPlugin implements SearchPlugin {
30+
public class TestQueryVectorBuilderPlugin implements SearchPlugin {
3131

32-
static class TestQueryVectorBuilder implements QueryVectorBuilder {
32+
public static class TestQueryVectorBuilder implements QueryVectorBuilder {
3333
private static final String NAME = "test_query_vector_builder";
3434

3535
private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
@@ -47,11 +47,11 @@ static class TestQueryVectorBuilder implements QueryVectorBuilder {
4747

4848
private List<Float> vectorToBuild;
4949

50-
TestQueryVectorBuilder(List<Float> vectorToBuild) {
50+
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
5151
this.vectorToBuild = vectorToBuild;
5252
}
5353

54-
TestQueryVectorBuilder(float[] expected) {
54+
public TestQueryVectorBuilder(float[] expected) {
5555
this.vectorToBuild = new ArrayList<>(expected.length);
5656
for (float f : expected) {
5757
vectorToBuild.add(f);

test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.search.retriever;
1111

1212
import org.apache.lucene.search.ScoreDoc;
13+
import org.elasticsearch.index.query.QueryBuilder;
1314
import org.elasticsearch.search.rank.RankDoc;
1415
import org.elasticsearch.xcontent.XContentBuilder;
1516

@@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
2324
public static final String NAME = "test_compound_retriever_builder";
2425

2526
public TestCompoundRetrieverBuilder(int rankWindowSize) {
26-
this(new ArrayList<>(), rankWindowSize);
27+
this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
2728
}
2829

29-
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
30+
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
3031
super(childRetrievers, rankWindowSize);
32+
this.preFilterQueryBuilders = preFilterQueryBuilders;
3133
}
3234

3335
@Override
34-
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
35-
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
36+
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
37+
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
3638
}
3739

3840
@Override

x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ public QueryRuleRetrieverBuilder(
110110
Map<String, Object> matchCriteria,
111111
List<RetrieverSource> retrieverSource,
112112
int rankWindowSize,
113-
String retrieverName
113+
String retrieverName,
114+
List<QueryBuilder> preFilterQueryBuilders
114115
) {
115116
super(retrieverSource, rankWindowSize);
116117
this.rulesetIds = rulesetIds;
117118
this.matchCriteria = matchCriteria;
118119
this.retrieverName = retrieverName;
120+
this.preFilterQueryBuilders = preFilterQueryBuilders;
119121
}
120122

121123
@Override
@@ -156,8 +158,15 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
156158
}
157159

158160
@Override
159-
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
160-
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName);
161+
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
162+
return new QueryRuleRetrieverBuilder(
163+
rulesetIds,
164+
matchCriteria,
165+
newChildRetrievers,
166+
rankWindowSize,
167+
retrieverName,
168+
newPreFilterQueryBuilders
169+
);
161170
}
162171

163172
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ public TextSimilarityRankRetrieverBuilder(
129129
}
130130

131131
@Override
132-
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
132+
protected TextSimilarityRankRetrieverBuilder clone(
133+
List<RetrieverSource> newChildRetrievers,
134+
List<QueryBuilder> newPreFilterQueryBuilders
135+
) {
133136
return new TextSimilarityRankRetrieverBuilder(
134137
newChildRetrievers,
135138
inferenceId,
@@ -138,7 +141,7 @@ protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChil
138141
rankWindowSize,
139142
minScore,
140143
retrieverName,
141-
preFilterQueryBuilders
144+
newPreFilterQueryBuilders
142145
);
143146
}
144147

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.search.sort.SortOrder;
3434
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
3535
import org.elasticsearch.search.vectors.QueryVectorBuilder;
36+
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
3637
import org.elasticsearch.test.ESIntegTestCase;
3738
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
3839
import org.elasticsearch.xcontent.XContentBuilder;
@@ -57,7 +58,6 @@
5758
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
5859

5960
protected static String INDEX = "test_index";
60-
protected static final String ID_FIELD = "_id";
6161
protected static final String DOC_FIELD = "doc";
6262
protected static final String TEXT_FIELD = "text";
6363
protected static final String VECTOR_FIELD = "vector";
@@ -743,6 +743,42 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
743743
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
744744
}
745745

746+
public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
747+
final int rankWindowSize = 100;
748+
final int rankConstant = 10;
749+
SearchSourceBuilder source = new SearchSourceBuilder();
750+
// this will retrieve all docs, but only doc 7 is kept due to the top-level filter
751+
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
752+
// this will also retrieve just doc 7
753+
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
754+
"vector",
755+
null,
756+
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
757+
10,
758+
10,
759+
null
760+
);
761+
source.retriever(
762+
new RRFRetrieverBuilder(
763+
Arrays.asList(
764+
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
765+
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
766+
),
767+
rankWindowSize,
768+
rankConstant
769+
)
770+
);
771+
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
772+
source.size(10);
773+
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
774+
ElasticsearchAssertions.assertResponse(req, resp -> {
775+
assertNull(resp.pointInTimeId());
776+
assertNotNull(resp.getHits().getTotalHits());
777+
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
778+
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
779+
});
780+
}
781+
746782
public void testRewriteOnce() {
747783
final float[] vector = new float[] { 1 };
748784
AtomicInteger numAsyncCalls = new AtomicInteger();

0 commit comments

Comments
 (0)