Skip to content

Commit 6ffe27d

Browse files
authored
ESQL - KNN function uses prefilters when pushed down to Lucene (#131004)
1 parent 7146681 commit 6ffe27d

File tree

23 files changed

+842
-69
lines changed

23 files changed

+842
-69
lines changed

docs/reference/query-languages/esql/kibana/docs/functions/knn.md

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# top-n query at the shard level
44

55
knnSearch
6-
required_capability: knn_function_v2
6+
required_capability: knn_function_v3
77

88
// tag::knn-function[]
99
from colors metadata _score
@@ -30,7 +30,7 @@ chartreuse | [127.0, 255.0, 0.0]
3030
;
3131

3232
knnSearchWithSimilarityOption
33-
required_capability: knn_function_v2
33+
required_capability: knn_function_v3
3434

3535
from colors metadata _score
3636
| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
@@ -46,14 +46,13 @@ wheat | [245.0, 222.0, 179.0]
4646
;
4747

4848
knnHybridSearch
49-
required_capability: knn_function_v2
49+
required_capability: knn_function_v3
5050

5151
from colors metadata _score
52-
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140)
52+
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
5353
| where primary == true
5454
| sort _score desc, color asc
5555
| keep color, rgb_vector
56-
| limit 10
5756
;
5857

5958
color:text | rgb_vector:dense_vector
@@ -68,21 +67,45 @@ red | [255.0, 0.0, 0.0]
6867
yellow | [255.0, 255.0, 0.0]
6968
;
7069

71-
knnWithMultipleFunctions
72-
required_capability: knn_function_v2
70+
knnWithPrefilter
71+
required_capability: knn_function_v3
7372

7473
from colors metadata _score
75-
| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive")
74+
| where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green"))
7675
| sort _score desc, color asc
7776
| keep color, rgb_vector
7877
;
7978

8079
color:text | rgb_vector:dense_vector
8180
olive | [128.0, 128.0, 0.0]
81+
green | [0.0, 128.0, 0.0]
82+
;
83+
84+
knnWithNegatedPrefilter
85+
required_capability: knn_function_v3
86+
87+
from colors metadata _score
88+
| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate"))
89+
| sort _score desc, color asc
90+
| keep color, rgb_vector
91+
| LIMIT 10
92+
;
93+
94+
color:text | rgb_vector:dense_vector
95+
sienna | [160.0, 82.0, 45.0]
96+
peru | [205.0, 133.0, 63.0]
97+
golden rod | [218.0, 165.0, 32.0]
98+
brown | [165.0, 42.0, 42.0]
99+
firebrick | [178.0, 34.0, 34.0]
100+
chartreuse | [127.0, 255.0, 0.0]
101+
gray | [128.0, 128.0, 128.0]
102+
green | [0.0, 128.0, 0.0]
103+
maroon | [128.0, 0.0, 0.0]
104+
orange | [255.0, 165.0, 0.0]
82105
;
83106

84107
knnAfterKeep
85-
required_capability: knn_function_v2
108+
required_capability: knn_function_v3
86109

87110
from colors metadata _score
88111
| keep rgb_vector, color, _score
@@ -101,7 +124,7 @@ rgb_vector:dense_vector
101124
;
102125

103126
knnAfterDrop
104-
required_capability: knn_function_v2
127+
required_capability: knn_function_v3
105128

106129
from colors metadata _score
107130
| drop primary
@@ -120,7 +143,7 @@ lime | [0.0, 255.0, 0.0]
120143
;
121144

122145
knnAfterEval
123-
required_capability: knn_function_v2
146+
required_capability: knn_function_v3
124147

125148
from colors metadata _score
126149
| eval composed_name = locate(color, " ") > 0
@@ -139,14 +162,12 @@ golden rod | true
139162
;
140163

141164
knnWithConjunction
142-
required_capability: knn_function_v2
165+
required_capability: knn_function_v3
143166

144-
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
145167
from colors metadata _score
146-
| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*"
168+
| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*"
147169
| sort _score desc, color asc
148170
| keep color, hex_code, rgb_vector
149-
| limit 10
150171
;
151172

152173
color:text | hex_code:keyword | rgb_vector:dense_vector
@@ -160,11 +181,10 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
160181
;
161182

162183
knnWithDisjunctionAndFiltersConjunction
163-
required_capability: knn_function_v2
184+
required_capability: knn_function_v3
164185

165-
# TODO We need kNN prefiltering here so we get more candidates that pass the filter
166186
from colors metadata _score
167-
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true
187+
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true
168188
| keep color, rgb_vector, _score
169189
| sort _score desc, color asc
170190
| drop _score
@@ -183,8 +203,31 @@ red | [255.0, 0.0, 0.0]
183203
yellow | [255.0, 255.0, 0.0]
184204
;
185205

206+
knnWithNegationsAndFiltersConjunction
207+
required_capability: knn_function_v3
208+
209+
from colors metadata _score
210+
| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue")))
211+
| sort _score desc, color asc
212+
| keep color, rgb_vector
213+
| limit 10
214+
;
215+
216+
color:text | rgb_vector:dense_vector
217+
cyan | [0.0, 255.0, 255.0]
218+
turquoise | [64.0, 224.0, 208.0]
219+
aqua marine | [127.0, 255.0, 212.0]
220+
teal | [0.0, 128.0, 128.0]
221+
silver | [192.0, 192.0, 192.0]
222+
gray | [128.0, 128.0, 128.0]
223+
gainsboro | [220.0, 220.0, 220.0]
224+
thistle | [216.0, 191.0, 216.0]
225+
lavender | [230.0, 230.0, 250.0]
226+
azure | [240.0, 255.0, 255.0]
227+
;
228+
186229
knnWithNonPushableConjunction
187-
required_capability: knn_function_v2
230+
required_capability: knn_function_v3
188231

189232
from colors metadata _score
190233
| eval composed_name = locate(color, " ") > 0
@@ -208,7 +251,7 @@ maroon | false
208251
;
209252

210253
testKnnWithNonPushableDisjunctions
211-
required_capability: knn_function_v2
254+
required_capability: knn_function_v3
212255

213256
from colors metadata _score
214257
| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
@@ -224,7 +267,7 @@ papaya whip
224267
;
225268

226269
testKnnWithNonPushableDisjunctionsOnComplexExpressions
227-
required_capability: knn_function_v2
270+
required_capability: knn_function_v3
228271

229272
from colors metadata _score
230273
| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
@@ -239,7 +282,7 @@ indigo | false
239282
;
240283

241284
testKnnInStatsNonPushable
242-
required_capability: knn_function_v2
285+
required_capability: knn_function_v3
243286

244287
from colors
245288
| where length(color) < 10
@@ -251,7 +294,7 @@ c: long
251294
;
252295

253296
testKnnInStatsWithGrouping
254-
required_capability: knn_function_v2
297+
required_capability: knn_function_v3
255298
required_capability: full_text_functions_in_stats_where
256299

257300
from colors

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,29 @@ public void testKnnNonPushedDown() {
114114
}
115115
}
116116

117+
public void testKnnWithPrefilters() {
118+
float[] queryVector = new float[numDims];
119+
Arrays.fill(queryVector, 1.0f);
120+
121+
// knn retrieves only 5 candidates, so id > 5 must be applied as a prefilter; if it were applied as a post-filter, no results would be returned
122+
var query = String.format(Locale.ROOT, """
123+
FROM test METADATA _score
124+
| WHERE knn(vector, %s, 5) AND id > 5
125+
| KEEP id, floats, _score, vector
126+
| SORT _score DESC
127+
| LIMIT 5
128+
""", Arrays.toString(queryVector));
129+
130+
try (var resp = run(query)) {
131+
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
132+
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));
133+
134+
List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
135+
// K = 5 and the query has LIMIT 5, so exactly 5 rows are expected when id > 5 is applied as a prefilter
136+
assertEquals(5, valuesList.size());
137+
}
138+
}
139+
117140
public void testKnnWithLookupJoin() {
118141
float[] queryVector = new float[numDims];
119142
Arrays.fill(queryVector, 1.0f);
@@ -136,7 +159,7 @@ public void testKnnWithLookupJoin() {
136159

137160
@Before
138161
public void setup() throws IOException {
139-
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
162+
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());
140163

141164
var indexName = "test";
142165
var client = client().admin().indices();
@@ -163,7 +186,7 @@ public void setup() throws IOException {
163186
var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
164187
assertAcked(createRequest);
165188

166-
numDocs = randomIntBetween(10, 20);
189+
numDocs = randomIntBetween(15, 25);
167190
numDims = randomIntBetween(3, 10);
168191
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
169192
float value = 0.0f;
@@ -202,6 +225,5 @@ private void createAndPopulateLookupIndex(IndicesAdminClient client, String look
202225

203226
var createRequest = client.prepareCreate(lookupIndexName).setMapping(mapping).setSettings(settingsBuilder.build());
204227
assertAcked(createRequest);
205-
206228
}
207229
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,7 +1213,7 @@ public enum Cap {
12131213
/**
12141214
* Support knn function
12151215
*/
1216-
KNN_FUNCTION_V2(Build.current().isSnapshot()),
1216+
KNN_FUNCTION_V3(Build.current().isSnapshot()),
12171217

12181218
/**
12191219
* Support for the LIKE operator with a list of wildcards.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
481481
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
482482
def(Score.class, uni(Score::new), Score.NAME),
483483
def(Term.class, bi(Term::new), "term"),
484-
def(Knn.class, Knn::new, "knn"),
484+
def(Knn.class, quad(Knn::new), "knn"),
485485
def(StGeohash.class, StGeohash::new, "st_geohash"),
486486
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
487487
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
@@ -1208,4 +1208,8 @@ private static <T extends Function> TernaryBuilder<T> tri(TernaryBuilder<T> func
12081208
return function;
12091209
}
12101210

1211+
private static <T extends Function> QuaternaryBuilder<T> quad(QuaternaryBuilder<T> function) {
1212+
return function;
1213+
}
1214+
12111215
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,20 +166,19 @@ public boolean equals(Object obj) {
166166

167167
@Override
168168
public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
169-
// In isolation, full text functions are pushable to source. We check if there are no disjunctions in Or conditions
170169
return Translatable.YES;
171170
}
172171

173172
@Override
174173
public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
175-
return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(handler);
174+
return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(pushdownPredicates, handler);
176175
}
177176

178177
public QueryBuilder queryBuilder() {
179178
return queryBuilder;
180179
}
181180

182-
protected abstract Query translate(TranslatorHandler handler);
181+
protected abstract Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler);
183182

184183
public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder);
185184

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Kql.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
2323
import org.elasticsearch.xpack.esql.expression.function.Param;
2424
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
25+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
2526
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
2627
import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery;
2728

@@ -93,7 +94,7 @@ protected NodeInfo<? extends Expression> info() {
9394
}
9495

9596
@Override
96-
protected Query translate(TranslatorHandler handler) {
97+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
9798
return new KqlQuery(source(), Objects.toString(queryAsObject()));
9899
}
99100

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3636
import org.elasticsearch.xpack.esql.expression.function.Param;
3737
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
38+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3839
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3940
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
4041
import org.elasticsearch.xpack.esql.querydsl.query.MatchQuery;
@@ -423,7 +424,7 @@ public Object queryAsObject() {
423424
}
424425

425426
@Override
426-
protected Query translate(TranslatorHandler handler) {
427+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
427428
var fieldAttribute = fieldAsFieldAttribute();
428429
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
429430
String fieldName = getNameFromFieldAttribute(fieldAttribute);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MatchPhrase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3333
import org.elasticsearch.xpack.esql.expression.function.Param;
3434
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
35+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3536
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3637
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3738
import org.elasticsearch.xpack.esql.querydsl.query.MatchPhraseQuery;
@@ -278,7 +279,7 @@ public Object queryAsObject() {
278279
}
279280

280281
@Override
281-
protected Query translate(TranslatorHandler handler) {
282+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
282283
var fieldAttribute = fieldAsFieldAttribute();
283284
Check.notNull(fieldAttribute, "MatchPhrase must have a field attribute as the first argument");
284285
String fieldName = getNameFromFieldAttribute(fieldAttribute);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/MultiMatch.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
3232
import org.elasticsearch.xpack.esql.expression.function.Param;
3333
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
34+
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
3435
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
3536
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
3637
import org.elasticsearch.xpack.esql.querydsl.query.MultiMatchQuery;
@@ -335,7 +336,7 @@ protected NodeInfo<? extends Expression> info() {
335336
}
336337

337338
@Override
338-
protected Query translate(TranslatorHandler handler) {
339+
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
339340
Map<String, Float> fieldsWithBoost = new HashMap<>();
340341
for (Expression field : fields) {
341342
var fieldAttribute = Match.fieldAsFieldAttribute(field);

0 commit comments

Comments
 (0)