Skip to content

Commit e52a6f2

Browse files
[8.16] Fix for propagating filters from compound to inner retrievers (#117914) (#118047)
* Fix for propagating filters from compound to inner retrievers * fix for lucene 9 * Update CompoundRetrieverBuilder.java * Update CompoundRetrieverBuilder.java * Update CompoundRetrieverBuilder.java --------- Co-authored-by: Elastic Machine <[email protected]>
1 parent 0e0d624 commit e52a6f2

File tree

12 files changed

+168
-45
lines changed

12 files changed

+168
-45
lines changed

docs/changelog/117914.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117914
2+
summary: Fix for propagating filters from compound to inner retrievers
3+
area: Ranking
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.elasticsearch.action.search.SearchRequest;
2121
import org.elasticsearch.action.search.SearchResponse;
2222
import org.elasticsearch.action.search.TransportMultiSearchAction;
23+
import org.elasticsearch.features.NodeFeature;
2324
import org.elasticsearch.index.query.BoolQueryBuilder;
2425
import org.elasticsearch.index.query.QueryBuilder;
2526
import org.elasticsearch.index.query.QueryRewriteContext;
@@ -47,6 +48,8 @@
4748
*/
4849
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {
4950

51+
public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
52+
5053
public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
5154

5255
protected final int rankWindowSize;
@@ -65,9 +68,9 @@ public T addChild(RetrieverBuilder retrieverBuilder) {
6568

6669
/**
6770
* Returns a clone of the original retriever, replacing the sub-retrievers with
68-
* the provided {@code newChildRetrievers}.
71+
* the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
6972
*/
70-
protected abstract T clone(List<RetrieverSource> newChildRetrievers);
73+
protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);
7174

7275
/**
7376
* Combines the provided {@code rankResults} to return the final top documents.
@@ -86,13 +89,25 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
8689
}
8790

8891
// Rewrite prefilters
89-
boolean hasChanged = false;
92+
// We eagerly rewrite prefilters, because some of the innerRetrievers
93+
// could be compound too, so we want to propagate all the necessary filter information to them
94+
// and have it available as part of their own rewrite step
9095
var newPreFilters = rewritePreFilters(ctx);
91-
hasChanged |= newPreFilters != preFilterQueryBuilders;
96+
if (newPreFilters != preFilterQueryBuilders) {
97+
return clone(innerRetrievers, newPreFilters);
98+
}
9299

100+
boolean hasChanged = false;
93101
// Rewrite retriever sources
94102
List<RetrieverSource> newRetrievers = new ArrayList<>();
95103
for (var entry : innerRetrievers) {
104+
// we propagate the filters only for compound retrievers as they won't be attached through
105+
// the createSearchSourceBuilder.
106+
// We could remove this check, but we would end up adding the same filters
107+
// multiple times in case an inner retriever rewrites itself, when we re-enter to rewrite
108+
if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
109+
entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
110+
}
96111
RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
97112
if (newRetriever != entry.retriever) {
98113
newRetrievers.add(new RetrieverSource(newRetriever, null));
@@ -107,7 +122,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio
107122
}
108123
}
109124
if (hasChanged) {
110-
return clone(newRetrievers);
125+
return clone(newRetrievers, newPreFilters);
111126
}
112127

113128
// execute searches
@@ -167,12 +182,7 @@ public void onFailure(Exception e) {
167182
});
168183
});
169184

170-
return new RankDocsRetrieverBuilder(
171-
rankWindowSize,
172-
newRetrievers.stream().map(s -> s.retriever).toList(),
173-
results::get,
174-
newPreFilters
175-
);
185+
return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
176186
}
177187

178188
@Override

server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
184184
ll.onResponse(null);
185185
}));
186186
});
187-
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
188-
return rewritten;
187+
return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
189188
}
190189
return super.rewrite(ctx);
191190
}

server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
3333
final List<RetrieverBuilder> sources;
3434
final Supplier<RankDoc[]> rankDocs;
3535

36-
public RankDocsRetrieverBuilder(
37-
int rankWindowSize,
38-
List<RetrieverBuilder> sources,
39-
Supplier<RankDoc[]> rankDocs,
40-
List<QueryBuilder> preFilterQueryBuilders
41-
) {
36+
public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
4237
this.rankWindowSize = rankWindowSize;
4338
this.rankDocs = rankDocs;
4439
if (sources == null || sources.isEmpty()) {
4540
throw new IllegalArgumentException("sources must not be null or empty");
4641
}
4742
this.sources = sources;
48-
this.preFilterQueryBuilders = preFilterQueryBuilders;
4943
}
5044

5145
@Override
@@ -73,10 +67,6 @@ private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException
7367
@Override
7468
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
7569
assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
76-
var rewrittenFilters = rewritePreFilters(ctx);
77-
if (rewrittenFilters != preFilterQueryBuilders) {
78-
return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
79-
}
8070
return this;
8171
}
8272

@@ -94,7 +84,7 @@ public QueryBuilder topDocsQuery() {
9484
boolQuery.should(query);
9585
}
9686
}
97-
// ignore prefilters of this level, they are already propagated to children
87+
// ignore prefilters of this level, they were already propagated to children
9888
return boolQuery;
9989
}
10090

@@ -133,7 +123,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
133123
} else {
134124
rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
135125
}
136-
// ignore prefilters of this level, they are already propagated to children
126+
// ignore prefilters of this level, they were already propagated to children
137127
searchSourceBuilder.query(rankQuery);
138128
if (sourceHasMinScore()) {
139129
searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());

server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,7 @@ private List<QueryBuilder> preFilters(QueryRewriteContext queryRewriteContext) t
9595
}
9696

9797
private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
98-
return new RankDocsRetrieverBuilder(
99-
randomIntBetween(1, 100),
100-
innerRetrievers(queryRewriteContext),
101-
rankDocsSupplier(),
102-
preFilters(queryRewriteContext)
103-
);
98+
return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
10499
}
105100

106101
public void testExtractToSearchSourceBuilder() throws IOException {

server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
/**
2828
* A SearchPlugin to exercise query vector builder
2929
*/
30-
class TestQueryVectorBuilderPlugin implements SearchPlugin {
30+
public class TestQueryVectorBuilderPlugin implements SearchPlugin {
3131

32-
static class TestQueryVectorBuilder implements QueryVectorBuilder {
32+
public static class TestQueryVectorBuilder implements QueryVectorBuilder {
3333
private static final String NAME = "test_query_vector_builder";
3434

3535
private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
@@ -47,11 +47,11 @@ static class TestQueryVectorBuilder implements QueryVectorBuilder {
4747

4848
private List<Float> vectorToBuild;
4949

50-
TestQueryVectorBuilder(List<Float> vectorToBuild) {
50+
public TestQueryVectorBuilder(List<Float> vectorToBuild) {
5151
this.vectorToBuild = vectorToBuild;
5252
}
5353

54-
TestQueryVectorBuilder(float[] expected) {
54+
public TestQueryVectorBuilder(float[] expected) {
5555
this.vectorToBuild = new ArrayList<>(expected.length);
5656
for (float f : expected) {
5757
vectorToBuild.add(f);

test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.search.retriever;
1111

1212
import org.apache.lucene.search.ScoreDoc;
13+
import org.elasticsearch.index.query.QueryBuilder;
1314
import org.elasticsearch.search.rank.RankDoc;
1415
import org.elasticsearch.xcontent.XContentBuilder;
1516

@@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
2324
public static final String NAME = "test_compound_retriever_builder";
2425

2526
public TestCompoundRetrieverBuilder(int rankWindowSize) {
26-
this(new ArrayList<>(), rankWindowSize);
27+
this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
2728
}
2829

29-
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
30+
TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
3031
super(childRetrievers, rankWindowSize);
32+
this.preFilterQueryBuilders = preFilterQueryBuilders;
3133
}
3234

3335
@Override
34-
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
35-
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
36+
protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
37+
return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
3638
}
3739

3840
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ public TextSimilarityRankRetrieverBuilder(
130130
}
131131

132132
@Override
133-
protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
133+
protected TextSimilarityRankRetrieverBuilder clone(
134+
List<RetrieverSource> newChildRetrievers,
135+
List<QueryBuilder> newPreFilterQueryBuilders
136+
) {
134137
return new TextSimilarityRankRetrieverBuilder(
135138
newChildRetrievers,
136139
inferenceId,
@@ -139,7 +142,7 @@ protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChil
139142
rankWindowSize,
140143
minScore,
141144
retrieverName,
142-
preFilterQueryBuilders
145+
newPreFilterQueryBuilders
143146
);
144147
}
145148

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.elasticsearch.search.sort.SortOrder;
3434
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
3535
import org.elasticsearch.search.vectors.QueryVectorBuilder;
36+
import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
3637
import org.elasticsearch.test.ESIntegTestCase;
3738
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
3839
import org.elasticsearch.xcontent.XContentBuilder;
@@ -57,7 +58,6 @@
5758
public class RRFRetrieverBuilderIT extends ESIntegTestCase {
5859

5960
protected static String INDEX = "test_index";
60-
protected static final String ID_FIELD = "_id";
6161
protected static final String DOC_FIELD = "doc";
6262
protected static final String TEXT_FIELD = "text";
6363
protected static final String VECTOR_FIELD = "vector";
@@ -743,6 +743,42 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
743743
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
744744
}
745745

746+
public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
747+
final int rankWindowSize = 100;
748+
final int rankConstant = 10;
749+
SearchSourceBuilder source = new SearchSourceBuilder();
750+
// this will retrieve all docs, but only doc 7 due to the top-level filter
751+
StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
752+
// this will also retrieve just doc 7
753+
KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
754+
"vector",
755+
null,
756+
new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
757+
10,
758+
10,
759+
null
760+
);
761+
source.retriever(
762+
new RRFRetrieverBuilder(
763+
Arrays.asList(
764+
new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
765+
new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
766+
),
767+
rankWindowSize,
768+
rankConstant
769+
)
770+
);
771+
source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
772+
source.size(10);
773+
SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
774+
ElasticsearchAssertions.assertResponse(req, resp -> {
775+
assertNull(resp.pointInTimeId());
776+
assertNotNull(resp.getHits().getTotalHits());
777+
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
778+
assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
779+
});
780+
}
781+
746782
public void testRewriteOnce() {
747783
final float[] vector = new float[] { 1 };
748784
AtomicInteger numAsyncCalls = new AtomicInteger();

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import java.util.Set;
1414

15+
import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
1516
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;
1617

1718
/**
@@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
2324
public Set<NodeFeature> getFeatures() {
2425
return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
2526
}
27+
28+
@Override
29+
public Set<NodeFeature> getTestFeatures() {
30+
return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT);
31+
}
2632
}

0 commit comments

Comments
 (0)