Skip to content

Commit fc0e54b

Browse files
committed
Adds filters to KnnVectorQueryBuilder when translating
1 parent 61a1e4b commit fc0e54b

File tree

10 files changed

+101
-18
lines changed

10 files changed

+101
-18
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,14 @@ public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
168168

169169
@Override
170170
public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
171-
return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(handler);
171+
return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(pushdownPredicates, handler);
172172
}
173173

174174
public QueryBuilder queryBuilder() {
175175
return queryBuilder;
176176
}
177177

178-
protected abstract Query translate(TranslatorHandler handler);
178+
protected abstract Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler);
179179

180180
public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder);
181181

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
2323
import org.elasticsearch.xpack.esql.expression.function.Param;
2424
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
25+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
2526
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
2627
import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery;
2728

@@ -93,7 +94,7 @@ protected NodeInfo<? extends Expression> info() {
9394
}
9495

9596
@Override
96-
protected Query translate(TranslatorHandler handler) {
97+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
9798
return new KqlQuery(source(), Objects.toString(queryAsObject()));
9899
}
99100

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3636
import org.elasticsearch.xpack.esql.expression.function.Param;
3737
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
38+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3839
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3940
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
4041
import org.elasticsearch.xpack.esql.querydsl.query.MatchQuery;
@@ -423,7 +424,7 @@ public Object queryAsObject() {
423424
}
424425

425426
@Override
426-
protected Query translate(TranslatorHandler handler) {
427+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
427428
var fieldAttribute = fieldAsFieldAttribute();
428429
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
429430
String fieldName = getNameFromFieldAttribute(fieldAttribute);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3333
import org.elasticsearch.xpack.esql.expression.function.Param;
3434
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
35+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3536
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3637
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3738
import org.elasticsearch.xpack.esql.querydsl.query.MatchPhraseQuery;
@@ -278,7 +279,7 @@ public Object queryAsObject() {
278279
}
279280

280281
@Override
281-
protected Query translate(TranslatorHandler handler) {
282+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
282283
var fieldAttribute = fieldAsFieldAttribute();
283284
Check.notNull(fieldAttribute, "MatchPhrase must have a field attribute as the first argument");
284285
String fieldName = getNameFromFieldAttribute(fieldAttribute);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3232
import org.elasticsearch.xpack.esql.expression.function.Param;
3333
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
34+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3435
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3536
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3637
import org.elasticsearch.xpack.esql.querydsl.query.MultiMatchQuery;
@@ -335,7 +336,7 @@ protected NodeInfo<? extends Expression> info() {
335336
}
336337

337338
@Override
338-
protected Query translate(TranslatorHandler handler) {
339+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
339340
Map<String, Float> fieldsWithBoost = new HashMap<>();
340341
for (Expression field : fields) {
341342
var fieldAttribute = Match.fieldAsFieldAttribute(field);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/QueryString.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
2929
import org.elasticsearch.xpack.esql.expression.function.Param;
3030
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
31+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3132
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3233

3334
import java.io.IOException;
@@ -345,7 +346,7 @@ protected NodeInfo<? extends Expression> info() {
345346
}
346347

347348
@Override
348-
protected Query translate(TranslatorHandler handler) {
349+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
349350
return new QueryStringQuery(source(), Objects.toString(queryAsObject()), Map.of(), queryStringOptions());
350351
}
351352

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
2828
import org.elasticsearch.xpack.esql.expression.function.Param;
2929
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
30+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3031
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3132
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3233

@@ -130,7 +131,7 @@ protected TypeResolutions.ParamOrdinal queryParamOrdinal() {
130131
}
131132

132133
@Override
133-
protected Query translate(TranslatorHandler handler) {
134+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
134135
// Uses a term query that contributes to scoring
135136
return new TermQuery(source(), ((FieldAttribute) field()).name(), queryAsObject(), false, true);
136137
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.index.query.QueryBuilder;
1515
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
16+
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
1617
import org.elasticsearch.xpack.esql.common.Failures;
1718
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
1819
import org.elasticsearch.xpack.esql.core.expression.Expression;
@@ -34,6 +35,7 @@
3435
import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction;
3536
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
3637
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
38+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3739
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3840
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3941
import org.elasticsearch.xpack.esql.querydsl.query.KnnQuery;
@@ -252,10 +254,10 @@ private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
252254
}
253255

254256
@Override
255-
protected Query translate(TranslatorHandler handler) {
257+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
256258
var fieldAttribute = Match.fieldAsFieldAttribute(field());
257259

258-
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
260+
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
259261
String fieldName = getNameFromFieldAttribute(fieldAttribute);
260262
@SuppressWarnings("unchecked")
261263
List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */);
@@ -268,14 +270,30 @@ protected Query translate(TranslatorHandler handler) {
268270
Map<String, Object> opts = queryOptions();
269271
opts.put(K_FIELD.getPreferredName(), kValue);
270272

271-
return new KnnQuery(source(), fieldName, queryAsFloats, opts);
273+
List<QueryBuilder> filterQueries = new ArrayList<>();
274+
for (Expression filterExpression : filterExpressions()) {
275+
filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder());
276+
}
277+
278+
return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries);
272279
}
273280

274281
@Override
275282
public Expression replaceQueryBuilder(QueryBuilder queryBuilder) {
276283
return new Knn(source(), field(), query(), k(), options(), queryBuilder, filterExpressions());
277284
}
278285

286+
@Override
287+
public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
288+
Translatable translatable = super.translatable(pushdownPredicates);
289+
// We need to check whether filter expressions are translatable as well
290+
for(Expression filterExpression : filterExpressions()) {
291+
translatable = translatable.merge(TranslationAware.translatable(filterExpression, pushdownPredicates));
292+
}
293+
294+
return translatable;
295+
}
296+
279297
public Expression withFilters(List<Expression> filterExpressions) {
280298
return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions);
281299
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/KnnQuery.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ public class KnnQuery extends Query {
3333

3434
public static final String RESCORE_OVERSAMPLE_FIELD = "rescore_oversample";
3535

36-
public KnnQuery(Source source, String field, float[] query, Map<String, Object> options) {
37-
this(source, field, query, options, List.of());
38-
}
39-
4036
public KnnQuery(Source source, String field, float[] query, Map<String, Object> options, List<QueryBuilder> filterQueries) {
4137
super(source);
4238
assert options != null;
@@ -67,7 +63,7 @@ protected QueryBuilder asBuilder() {
6763
}
6864
return queryBuilder;
6965
}
70-
66+
7167
public KnnQuery withFilterQueries(List<QueryBuilder> newFilterQueries) {
7268
List<QueryBuilder> combinedFilterQueries = new ArrayList<>(filterQueries);
7369
combinedFilterQueries.addAll(newFilterQueries);
@@ -100,7 +96,7 @@ public int hashCode() {
10096
public boolean scorable() {
10197
return true;
10298
}
103-
99+
104100
public List<QueryBuilder> filterQueries() {
105101
return filterQueries;
106102
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
6161
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator;
6262
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
63+
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
64+
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
6365
import org.elasticsearch.xpack.esql.expression.predicate.logical.Or;
6466
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
6567
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual;
@@ -104,6 +106,7 @@
104106

105107
import java.io.IOException;
106108
import java.util.ArrayList;
109+
import java.util.Arrays;
107110
import java.util.Collection;
108111
import java.util.List;
109112
import java.util.Locale;
@@ -1626,6 +1629,10 @@ private void testFullTextFunctionWithPushableConjunction(FullTextFunctionTestCas
16261629
assertEquals(expected.toString(), esQuery.query().toString());
16271630
}
16281631

1632+
public void testKnn() {
1633+
testFullTextFunctionWithNonPushableDisjunction(new KnnFunctionTestCase());
1634+
}
1635+
16291636
private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTestCase testCase) {
16301637
String query = String.format(Locale.ROOT, """
16311638
from test
@@ -1646,6 +1653,32 @@ private void testFullTextFunctionWithNonPushableDisjunction(FullTextFunctionTest
16461653
assertThat(fieldExtract.child(), instanceOf(EsQueryExec.class));
16471654
}
16481655

1656+
public void testKnnPrefilters() {
1657+
String query = """
1658+
from test
1659+
| where knn(dense_vector, [0, 1, 2], 10) and integer > 10
1660+
""";
1661+
var plan = plannerOptimizer.plan(query, IS_SV_STATS, makeAnalyzer("mapping-all-types.json"));
1662+
1663+
var limit = as(plan, LimitExec.class);
1664+
var exchange = as(limit.child(), ExchangeExec.class);
1665+
var project = as(exchange.child(), ProjectExec.class);
1666+
var field = as(project.child(), FieldExtractExec.class);
1667+
var queryExec = as(field.child(), EsQueryExec.class);
1668+
QueryBuilder expectedFilterQueryBuilder = wrapWithSingleQuery(
1669+
query,
1670+
unscore(rangeQuery("integer").gt(10)),
1671+
"integer",
1672+
new Source(2, 45, "integer > 10")
1673+
);
1674+
KnnVectorQueryBuilder expectedKnnQueryBuilder = new KnnVectorQueryBuilder("dense_vector", new float[] { 0, 1, 2 }, 10, null, null, null)
1675+
.addFilterQuery(expectedFilterQueryBuilder);
1676+
var expectedQuery = boolQuery()
1677+
.must(expectedKnnQueryBuilder)
1678+
.must(expectedFilterQueryBuilder);
1679+
assertEquals(expectedQuery.toString(), queryExec.query().toString());
1680+
}
1681+
16491682
private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCase testCase) {
16501683
String query = String.format(Locale.ROOT, """
16511684
from test
@@ -1665,11 +1698,12 @@ private void testFullTextFunctionWithPushableDisjunction(FullTextFunctionTestCas
16651698
}
16661699

16671700
private FullTextFunctionTestCase randomFullTextFunctionTestCase() {
1668-
return switch (randomIntBetween(0, 3)) {
1701+
return switch (randomIntBetween(0, 4)) {
16691702
case 0 -> new MatchFunctionTestCase();
16701703
case 1 -> new MatchOperatorTestCase();
16711704
case 2 -> new KqlFunctionTestCase();
16721705
case 3 -> new QueryStringFunctionTestCase();
1706+
case 4 -> new KnnFunctionTestCase();
16731707
default -> throw new IllegalStateException("Unexpected value");
16741708
};
16751709
}
@@ -2190,4 +2224,33 @@ public String esqlQuery() {
21902224
return "qstr(\"" + fieldName() + ": " + queryString() + "\")";
21912225
}
21922226
}
2227+
2228+
private class KnnFunctionTestCase extends FullTextFunctionTestCase {
2229+
2230+
final int k;
2231+
2232+
KnnFunctionTestCase() {
2233+
super(Knn.class, "dense_vector", randomVector());
2234+
k = randomIntBetween(1, 10);
2235+
}
2236+
2237+
private static Object randomVector() {
2238+
int numDims = randomIntBetween(10, 20);
2239+
float[] vector = new float[numDims];
2240+
for (int i = 0; i < numDims; i++) {
2241+
vector[i] = randomFloat();
2242+
}
2243+
return vector;
2244+
}
2245+
2246+
@Override
2247+
public QueryBuilder queryBuilder() {
2248+
return new KnnVectorQueryBuilder(fieldName(), (float[]) queryString(), k, null, null, null);
2249+
}
2250+
2251+
@Override
2252+
public String esqlQuery() {
2253+
return "knn(" + fieldName() + ", " + Arrays.toString(((float[]) queryString())) + ", " + k + ")";
2254+
}
2255+
}
21932256
}

0 commit comments

Comments
 (0)