Skip to content

Commit 1708d9e

Browse files
authored
Ensure that all rewriteable are called in retrievers (#114366)
This PR ensures that every retriever applies the rewrite to all of its rewriteable components. Rewriting eagerly at the retriever level ensures that we don't rewrite the same query multiple times when compound retrievers are used.
1 parent db8a2d2 commit 1708d9e

File tree

5 files changed

+204
-10
lines changed

5 files changed

+204
-10
lines changed

server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
package org.elasticsearch.search.retriever;
1111

12+
import org.apache.lucene.util.SetOnce;
1213
import org.elasticsearch.common.ParsingException;
1314
import org.elasticsearch.features.NodeFeature;
1415
import org.elasticsearch.index.query.BoolQueryBuilder;
1516
import org.elasticsearch.index.query.QueryBuilder;
17+
import org.elasticsearch.index.query.QueryRewriteContext;
1618
import org.elasticsearch.search.builder.SearchSourceBuilder;
1719
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
1820
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
@@ -29,7 +31,9 @@
2931
import java.util.Arrays;
3032
import java.util.List;
3133
import java.util.Objects;
34+
import java.util.function.Supplier;
3235

36+
import static org.elasticsearch.common.Strings.format;
3337
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3438
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3539

@@ -96,7 +100,7 @@ public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
96100
}
97101

98102
private final String field;
99-
private final float[] queryVector;
103+
private final Supplier<float[]> queryVector;
100104
private final QueryVectorBuilder queryVectorBuilder;
101105
private final int k;
102106
private final int numCands;
@@ -110,23 +114,85 @@ public KnnRetrieverBuilder(
110114
int numCands,
111115
Float similarity
112116
) {
117+
if (queryVector == null && queryVectorBuilder == null) {
118+
throw new IllegalArgumentException(
119+
format(
120+
"either [%s] or [%s] must be provided",
121+
QUERY_VECTOR_FIELD.getPreferredName(),
122+
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
123+
)
124+
);
125+
} else if (queryVector != null && queryVectorBuilder != null) {
126+
throw new IllegalArgumentException(
127+
format(
128+
"only one of [%s] and [%s] must be provided",
129+
QUERY_VECTOR_FIELD.getPreferredName(),
130+
QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
131+
)
132+
);
133+
}
113134
this.field = field;
114-
this.queryVector = queryVector;
135+
this.queryVector = queryVector != null ? () -> queryVector : null;
115136
this.queryVectorBuilder = queryVectorBuilder;
116137
this.k = k;
117138
this.numCands = numCands;
118139
this.similarity = similarity;
119140
}
120141

121-
// ---- FOR TESTING XCONTENT PARSING ----
142+
/**
 * Copy constructor: takes the field, k, numCands, similarity, retriever name and pre-filters from
 * {@code clone}, and stores the supplied query vector supplier and query vector builder as given.
 *
 * @param clone              the retriever whose remaining settings are copied
 * @param queryVector        supplier of the query vector to store on the new instance
 * @param queryVectorBuilder query vector builder to store on the new instance
 */
private KnnRetrieverBuilder(KnnRetrieverBuilder clone, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
    // settings carried over unchanged from the source retriever
    this.field = clone.field;
    this.k = clone.k;
    this.numCands = clone.numCands;
    this.similarity = clone.similarity;
    this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
    this.retrieverName = clone.retrieverName;

    // vector source chosen by the caller
    this.queryVectorBuilder = queryVectorBuilder;
    this.queryVector = queryVector;
}
122152

123153
/** Returns the {@code NAME} constant as the name of this retriever. */
@Override
124154
public String getName() {
125155
return NAME;
126156
}
127157

158+
@Override
159+
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
160+
var rewrittenFilters = rewritePreFilters(ctx);
161+
if (rewrittenFilters != preFilterQueryBuilders) {
162+
var rewritten = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder);
163+
rewritten.preFilterQueryBuilders = rewrittenFilters;
164+
return rewritten;
165+
}
166+
167+
if (queryVectorBuilder != null) {
168+
SetOnce<float[]> toSet = new SetOnce<>();
169+
ctx.registerAsyncAction((c, l) -> {
170+
queryVectorBuilder.buildVector(c, l.delegateFailureAndWrap((ll, v) -> {
171+
toSet.set(v);
172+
if (v == null) {
173+
ll.onFailure(
174+
new IllegalArgumentException(
175+
format(
176+
"[%s] with name [%s] returned null query_vector",
177+
QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
178+
queryVectorBuilder.getWriteableName()
179+
)
180+
)
181+
);
182+
return;
183+
}
184+
ll.onResponse(null);
185+
}));
186+
});
187+
var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
188+
return rewritten;
189+
}
190+
return super.rewrite(ctx);
191+
}
192+
128193
@Override
129194
public QueryBuilder topDocsQuery() {
195+
assert queryVector != null : "query vector must be materialized at this point";
130196
assert rankDocs != null : "rankDocs should have been materialized by now";
131197
var rankDocsQuery = new RankDocsQueryBuilder(rankDocs, null, true);
132198
if (preFilterQueryBuilders.isEmpty()) {
@@ -139,10 +205,11 @@ public QueryBuilder topDocsQuery() {
139205

140206
@Override
141207
public QueryBuilder explainQuery() {
208+
assert queryVector != null : "query vector must be materialized at this point";
142209
assert rankDocs != null : "rankDocs should have been materialized by now";
143210
var rankDocsQuery = new RankDocsQueryBuilder(
144211
rankDocs,
145-
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector), field, similarity) },
212+
new QueryBuilder[] { new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity) },
146213
true
147214
);
148215
if (preFilterQueryBuilders.isEmpty()) {
@@ -155,10 +222,11 @@ public QueryBuilder explainQuery() {
155222

156223
@Override
157224
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
225+
assert queryVector != null : "query vector must be materialized at this point.";
158226
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(
159227
field,
160-
VectorData.fromFloats(queryVector),
161-
queryVectorBuilder,
228+
VectorData.fromFloats(queryVector.get()),
229+
null,
162230
k,
163231
numCands,
164232
similarity
@@ -174,14 +242,16 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
174242
searchSourceBuilder.knnSearch(knnSearchBuilders);
175243
}
176244

245+
// ---- FOR TESTING XCONTENT PARSING ----
246+
177247
@Override
178248
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
179249
builder.field(FIELD_FIELD.getPreferredName(), field);
180250
builder.field(K_FIELD.getPreferredName(), k);
181251
builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
182252

183253
if (queryVector != null) {
184-
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
254+
builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
185255
}
186256

187257
if (queryVectorBuilder != null) {
@@ -199,15 +269,16 @@ public boolean doEquals(Object o) {
199269
return k == that.k
200270
&& numCands == that.numCands
201271
&& Objects.equals(field, that.field)
202-
&& Arrays.equals(queryVector, that.queryVector)
272+
&& ((queryVector == null && that.queryVector == null)
273+
|| (queryVector != null && that.queryVector != null && Arrays.equals(queryVector.get(), that.queryVector.get())))
203274
&& Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
204275
&& Objects.equals(similarity, that.similarity);
205276
}
206277

207278
@Override
208279
public int doHashCode() {
209280
int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
210-
result = 31 * result + Arrays.hashCode(queryVector);
281+
result = 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null);
211282
return result;
212283
}
213284

server/src/main/java/org/elasticsearch/search/retriever/StandardRetrieverBuilder.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.index.query.AbstractQueryBuilder;
1515
import org.elasticsearch.index.query.BoolQueryBuilder;
1616
import org.elasticsearch.index.query.QueryBuilder;
17+
import org.elasticsearch.index.query.QueryRewriteContext;
1718
import org.elasticsearch.search.builder.SearchSourceBuilder;
1819
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
1920
import org.elasticsearch.search.collapse.CollapseBuilder;
@@ -27,6 +28,7 @@
2728
import org.elasticsearch.xcontent.XContentParser;
2829

2930
import java.io.IOException;
31+
import java.util.ArrayList;
3032
import java.util.List;
3133
import java.util.Objects;
3234

@@ -105,6 +107,48 @@ public StandardRetrieverBuilder(QueryBuilder queryBuilder) {
105107
this.queryBuilder = queryBuilder;
106108
}
107109

110+
/**
 * Copy constructor: creates a retriever whose fields all reference the same values as those of
 * {@code clone}. No field is deep-copied.
 *
 * @param clone the retriever to copy from
 */
private StandardRetrieverBuilder(StandardRetrieverBuilder clone) {
    // query, filters and sorting
    this.queryBuilder = clone.queryBuilder;
    this.preFilterQueryBuilders = clone.preFilterQueryBuilders;
    this.sortBuilders = clone.sortBuilders;

    // remaining search options
    this.minScore = clone.minScore;
    this.collapseBuilder = clone.collapseBuilder;
    this.searchAfterBuilder = clone.searchAfterBuilder;
    this.terminateAfter = clone.terminateAfter;

    this.retrieverName = clone.retrieverName;
}
120+
121+
@Override
122+
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
123+
boolean changed = false;
124+
List<SortBuilder<?>> newSortBuilders = null;
125+
if (sortBuilders != null) {
126+
newSortBuilders = new ArrayList<>(sortBuilders.size());
127+
for (var sort : sortBuilders) {
128+
var newSort = sort.rewrite(ctx);
129+
newSortBuilders.add(newSort);
130+
changed = newSort != sort;
131+
}
132+
}
133+
var rewrittenFilters = rewritePreFilters(ctx);
134+
changed |= rewrittenFilters != preFilterQueryBuilders;
135+
136+
QueryBuilder queryBuilderRewrite = null;
137+
if (queryBuilder != null) {
138+
queryBuilderRewrite = queryBuilder.rewrite(ctx);
139+
changed |= queryBuilderRewrite != queryBuilder;
140+
}
141+
142+
if (changed) {
143+
var rewritten = new StandardRetrieverBuilder(this);
144+
rewritten.sortBuilders = newSortBuilders;
145+
rewritten.preFilterQueryBuilders = preFilterQueryBuilders;
146+
rewritten.queryBuilder = queryBuilderRewrite;
147+
return rewritten;
148+
}
149+
return this;
150+
}
151+
108152
@Override
109153
public QueryBuilder topDocsQuery() {
110154
if (preFilterQueryBuilders.isEmpty()) {

server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.common.io.stream.StreamOutput;
1919
import org.elasticsearch.index.query.AbstractQueryBuilder;
2020
import org.elasticsearch.index.query.QueryBuilder;
21+
import org.elasticsearch.index.query.QueryRewriteContext;
2122
import org.elasticsearch.index.query.SearchExecutionContext;
2223
import org.elasticsearch.search.rank.RankDoc;
2324
import org.elasticsearch.xcontent.XContentBuilder;
@@ -54,6 +55,22 @@ public RankDocsQueryBuilder(StreamInput in) throws IOException {
5455
}
5556
}
5657

58+
@Override
59+
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
60+
if (queryBuilders != null) {
61+
QueryBuilder[] newQueryBuilders = new QueryBuilder[queryBuilders.length];
62+
boolean changed = false;
63+
for (int i = 0; i < newQueryBuilders.length; i++) {
64+
newQueryBuilders[i] = queryBuilders[i].rewrite(queryRewriteContext);
65+
changed |= newQueryBuilders[i] != queryBuilders[i];
66+
}
67+
if (changed) {
68+
return new RankDocsQueryBuilder(rankDocs, newQueryBuilders, onlyRankDocs);
69+
}
70+
}
71+
return super.doRewrite(queryRewriteContext);
72+
}
73+
5774
// Package-private accessor: returns the rankDocs array held by this builder (the same reference, not a copy).
RankDoc[] rankDocs() {
5875
return rankDocs;
5976
}

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, I
125125
this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
126126
}
127127

128-
protected KnnVectorQueryBuilder(
128+
public KnnVectorQueryBuilder(
129129
String fieldName,
130130
QueryVectorBuilder queryVectorBuilder,
131131
Integer k,

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
package org.elasticsearch.xpack.rank.rrf;
99

1010
import org.apache.lucene.search.TotalHits;
11+
import org.elasticsearch.TransportVersion;
12+
import org.elasticsearch.action.ActionListener;
1113
import org.elasticsearch.action.search.SearchRequestBuilder;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.common.io.stream.StreamOutput;
1216
import org.elasticsearch.common.settings.Settings;
1317
import org.elasticsearch.index.query.InnerHitBuilder;
1418
import org.elasticsearch.index.query.QueryBuilder;
@@ -24,16 +28,23 @@
2428
import org.elasticsearch.search.retriever.TestRetrieverBuilder;
2529
import org.elasticsearch.search.sort.FieldSortBuilder;
2630
import org.elasticsearch.search.sort.SortOrder;
31+
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
32+
import org.elasticsearch.search.vectors.QueryVectorBuilder;
2733
import org.elasticsearch.test.ESIntegTestCase;
2834
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
35+
import org.elasticsearch.xcontent.XContentBuilder;
2936
import org.elasticsearch.xcontent.XContentType;
3037
import org.junit.Before;
3138

39+
import java.io.IOException;
3240
import java.util.Arrays;
3341
import java.util.Collection;
3442
import java.util.List;
43+
import java.util.concurrent.atomic.AtomicInteger;
3544

3645
import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
46+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
47+
import static org.hamcrest.CoreMatchers.is;
3748
import static org.hamcrest.Matchers.containsString;
3849
import static org.hamcrest.Matchers.equalTo;
3950
import static org.hamcrest.Matchers.greaterThan;
@@ -652,4 +663,55 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder
652663
source.aggregation(AggregationBuilders.terms("topic_agg").field(TOPIC_FIELD));
653664
expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
654665
}
666+
667+
public void testRewriteOnce() {
668+
final float[] vector = new float[] { 1 };
669+
AtomicInteger numAsyncCalls = new AtomicInteger();
670+
QueryVectorBuilder vectorBuilder = new QueryVectorBuilder() {
671+
@Override
672+
public void buildVector(Client client, ActionListener<float[]> listener) {
673+
numAsyncCalls.incrementAndGet();
674+
listener.onResponse(vector);
675+
}
676+
677+
@Override
678+
public String getWriteableName() {
679+
throw new IllegalStateException("Should not be called");
680+
}
681+
682+
@Override
683+
public TransportVersion getMinimalSupportedVersion() {
684+
throw new IllegalStateException("Should not be called");
685+
}
686+
687+
@Override
688+
public void writeTo(StreamOutput out) throws IOException {
689+
throw new IllegalStateException("Should not be called");
690+
}
691+
692+
@Override
693+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
694+
throw new IllegalStateException("Should not be called");
695+
}
696+
};
697+
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null);
698+
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
699+
var rrf = new RRFRetrieverBuilder(
700+
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
701+
10,
702+
10
703+
);
704+
assertResponse(
705+
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf)),
706+
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
707+
);
708+
assertThat(numAsyncCalls.get(), equalTo(2));
709+
710+
// check that we use the rewritten vector to build the explain query
711+
assertResponse(
712+
client().prepareSearch(INDEX).setSource(new SearchSourceBuilder().retriever(rrf).explain(true)),
713+
searchResponse -> assertThat(searchResponse.getHits().getTotalHits().value, is(4L))
714+
);
715+
assertThat(numAsyncCalls.get(), equalTo(4));
716+
}
655717
}

0 commit comments

Comments
 (0)