Skip to content

Commit e76c2a6

Browse files
authored
ESQL - KNN function uses LIMIT for K, transforms to exact search when not pushed down (elastic#132944)
1 parent 22af544 commit e76c2a6

File tree

25 files changed

+572
-230
lines changed

25 files changed

+572
-230
lines changed

docs/reference/query-languages/esql/_snippets/functions/examples/knn.md

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/_snippets/functions/functionNamedParams/knn.md

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/_snippets/functions/parameters/knn.md

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/images/functions/knn.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/reference/query-languages/esql/kibana/definition/functions/knn.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/query-languages/esql/kibana/docs/functions/knn.md

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneQueryEvaluator.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,16 @@ protected LuceneQueryEvaluator(BlockFactory blockFactory, ShardConfig[] shards)
6060
}
6161

6262
public Block executeQuery(Page page) {
63-
// Lucene based operators retrieve DocVectors as first block
64-
Block block = page.getBlock(0);
65-
assert block instanceof DocBlock : "LuceneQueryExpressionEvaluator expects DocBlock as input";
66-
DocVector docs = (DocVector) block.asVector();
63+
// Search for DocVector block
64+
Block docBlock = null;
65+
for (int i = 0; i < page.getBlockCount(); i++) {
66+
if (page.getBlock(i) instanceof DocBlock) {
67+
docBlock = page.getBlock(i);
68+
break;
69+
}
70+
}
71+
assert docBlock != null : "LuceneQueryExpressionEvaluator expects a DocBlock";
72+
DocVector docs = (DocVector) docBlock.asVector();
6773
try {
6874
if (docs.singleSegmentNonDecreasing()) {
6975
return evalSingleSegmentNonDecreasing(docs);

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/ScoreOperator.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.compute.data.Block;
1111
import org.elasticsearch.compute.data.BlockFactory;
12-
import org.elasticsearch.compute.data.DocVector;
1312
import org.elasticsearch.compute.data.DoubleBlock;
1413
import org.elasticsearch.compute.data.DoubleVector;
1514
import org.elasticsearch.compute.data.Page;
@@ -46,9 +45,9 @@ public ScoreOperator(BlockFactory blockFactory, ExpressionScorer scorer, int sco
4645

4746
@Override
4847
protected Page process(Page page) {
49-
assert page.getBlockCount() >= 2 : "Expected at least 2 blocks, got " + page.getBlockCount();
50-
assert page.getBlock(0).asVector() instanceof DocVector : "Expected a DocVector, got " + page.getBlock(0).asVector();
51-
assert page.getBlock(1).asVector() instanceof DoubleVector : "Expected a DoubleVector, got " + page.getBlock(1).asVector();
48+
assert page.getBlockCount() > scoreBlockPosition : "Expected to get a score block in position " + scoreBlockPosition;
49+
assert page.getBlock(scoreBlockPosition).asVector() instanceof DoubleVector
50+
: "Expected a DoubleVector as a score block, got " + page.getBlock(scoreBlockPosition).asVector();
5251

5352
Block[] blocks = new Block[page.getBlockCount()];
5453
for (int i = 0; i < page.getBlockCount(); i++) {

x-pack/plugin/esql/qa/testFixtures/src/main/resources/knn-function.csv-spec

Lines changed: 87 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
# top-n query at the shard level
44

55
knnSearch
6-
required_capability: knn_function_v3
6+
required_capability: knn_function_v4
77

88
// tag::knn-function[]
99
from colors metadata _score
10-
| where knn(rgb_vector, [0, 120, 0], 10)
10+
| where knn(rgb_vector, [0, 120, 0])
1111
| sort _score desc, color asc
1212
// end::knn-function[]
1313
| keep color, rgb_vector
@@ -30,10 +30,10 @@ chartreuse | [127.0, 255.0, 0.0]
3030
;
3131

3232
knnSearchWithSimilarityOption
33-
required_capability: knn_function_v3
33+
required_capability: knn_function_v4
3434

3535
from colors metadata _score
36-
| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
36+
| where knn(rgb_vector, [255,192,203], {"similarity": 40})
3737
| sort _score desc, color asc
3838
| keep color, rgb_vector
3939
;
@@ -46,13 +46,14 @@ wheat | [245.0, 222.0, 179.0]
4646
;
4747

4848
knnHybridSearch
49-
required_capability: knn_function_v3
49+
required_capability: knn_function_v4
5050

5151
from colors metadata _score
52-
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
52+
| where match(color, "blue") or knn(rgb_vector, [65,105,225])
5353
| where primary == true
5454
| sort _score desc, color asc
5555
| keep color, rgb_vector
56+
| limit 10
5657
;
5758

5859
color:text | rgb_vector:dense_vector
@@ -68,10 +69,10 @@ yellow | [255.0, 255.0, 0.0]
6869
;
6970

7071
knnWithPrefilter
71-
required_capability: knn_function_v3
72+
required_capability: knn_function_v4
7273

7374
from colors
74-
| where knn(rgb_vector, [120,180,0], 10) and (match(color, "olive") or match(color, "green"))
75+
| where knn(rgb_vector, [120,180,0]) and (match(color, "olive") or match(color, "green"))
7576
| sort color asc
7677
| keep color
7778
;
@@ -82,10 +83,10 @@ olive
8283
;
8384

8485
knnWithNegatedPrefilter
85-
required_capability: knn_function_v3
86+
required_capability: knn_function_v4
8687

8788
from colors metadata _score
88-
| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate"))
89+
| where knn(rgb_vector, [128,128,0]) and not (match(color, "olive") or match(color, "chocolate"))
8990
| sort _score desc, color asc
9091
| keep color, rgb_vector
9192
| LIMIT 10
@@ -105,11 +106,11 @@ orange | [255.0, 165.0, 0.0]
105106
;
106107

107108
knnAfterKeep
108-
required_capability: knn_function_v3
109+
required_capability: knn_function_v4
109110

110111
from colors metadata _score
111112
| keep rgb_vector, color, _score
112-
| where knn(rgb_vector, [128,255,0], 140)
113+
| where knn(rgb_vector, [128,255,0])
113114
| sort _score desc, color asc
114115
| keep rgb_vector
115116
| limit 5
@@ -124,11 +125,11 @@ rgb_vector:dense_vector
124125
;
125126

126127
knnAfterDrop
127-
required_capability: knn_function_v3
128+
required_capability: knn_function_v4
128129

129130
from colors metadata _score
130131
| drop primary
131-
| where knn(rgb_vector, [128,250,0], 140)
132+
| where knn(rgb_vector, [128,250,0])
132133
| sort _score desc, color asc
133134
| keep color, rgb_vector
134135
| limit 5
@@ -143,11 +144,11 @@ lime | [0.0, 255.0, 0.0]
143144
;
144145

145146
knnAfterEval
146-
required_capability: knn_function_v3
147+
required_capability: knn_function_v4
147148

148149
from colors metadata _score
149150
| eval composed_name = locate(color, " ") > 0
150-
| where knn(rgb_vector, [128,128,0], 140)
151+
| where knn(rgb_vector, [128,128,0])
151152
| sort _score desc, color asc
152153
| keep color, composed_name
153154
| limit 5
@@ -162,12 +163,13 @@ golden rod | true
162163
;
163164

164165
knnWithConjunction
165-
required_capability: knn_function_v3
166+
required_capability: knn_function_v4
166167

167168
from colors metadata _score
168-
| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*"
169+
| where knn(rgb_vector, [255,255,238]) and hex_code like "#FFF*"
169170
| sort _score desc, color asc
170171
| keep color, hex_code, rgb_vector
172+
| limit 10
171173
;
172174

173175
color:text | hex_code:keyword | rgb_vector:dense_vector
@@ -181,10 +183,10 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
181183
;
182184

183185
knnWithDisjunctionAndFiltersConjunction
184-
required_capability: knn_function_v3
186+
required_capability: knn_function_v4
185187

186188
from colors metadata _score
187-
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true
189+
| where (knn(rgb_vector, [0,255,255]) or knn(rgb_vector, [128, 0, 255])) and primary == true
188190
| keep color, rgb_vector, _score
189191
| sort _score desc, color asc
190192
| drop _score
@@ -204,10 +206,10 @@ yellow | [255.0, 255.0, 0.0]
204206
;
205207

206208
knnWithNegationsAndFiltersConjunction
207-
required_capability: knn_function_v3
209+
required_capability: knn_function_v4
208210

209211
from colors metadata _score
210-
| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue")))
212+
| where (knn(rgb_vector, [0,255,255]) and not(primary == true and match(color, "blue")))
211213
| sort _score desc, color asc
212214
| keep color, rgb_vector
213215
| limit 10
@@ -227,11 +229,11 @@ azure | [240.0, 255.0, 255.0]
227229
;
228230

229231
knnWithNonPushableConjunction
230-
required_capability: knn_function_v3
232+
required_capability: knn_function_v4
231233

232234
from colors metadata _score
233235
| eval composed_name = locate(color, " ") > 0
234-
| where knn(rgb_vector, [128,128,0], 140) and composed_name == false
236+
| where knn(rgb_vector, [128,128,0], {"min_candidates": 100}) and composed_name == false
235237
| sort _score desc, color asc
236238
| keep color, composed_name
237239
| limit 10
@@ -251,58 +253,88 @@ maroon | false
251253
;
252254

253255
testKnnWithNonPushableDisjunctions
254-
required_capability: knn_function_v3
256+
required_capability: knn_function_v4
255257

256258
from colors metadata _score
257-
| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
259+
| where knn(rgb_vector, [128,128,0]) or length(color) > 10
258260
| sort _score desc, color asc
259-
| keep color
261+
| keep color
262+
| limit 10
260263
;
261264

262265
color:text
263-
olive
264-
aqua marine
265-
lemon chiffon
266-
papaya whip
266+
olive
267+
sienna
268+
chocolate
269+
peru
270+
golden rod
271+
brown
272+
firebrick
273+
chartreuse
274+
gray
275+
green
267276
;
268277

269-
testKnnWithNonPushableDisjunctionsOnComplexExpressions
270-
required_capability: knn_function_v3
278+
testKnnWithNonPushableDisjunctionsAndMinCandidates
279+
required_capability: knn_function_v4
271280

272281
from colors metadata _score
273-
| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
282+
| where (knn(rgb_vector, [128,128,0], {"min_candidates": 2}) and length(color) > 10) or (knn(rgb_vector, [128,0,128], {"min_candidates": 2}) and primary == true)
274283
| sort _score desc, color asc
275284
| keep color, primary
276285
;
277286

278287
color:text | primary:boolean
279-
olive | false
280-
purple | false
281-
indigo | false
282-
;
288+
gray | true
289+
green | true
290+
red | true
291+
black | true
292+
magenta | true
293+
yellow | true
294+
blue | true
295+
aqua marine | false
296+
papaya whip | false
297+
lemon chiffon | false
298+
white | true
299+
cyan | true
300+
;
301+
302+
testKnnWithStats
303+
required_capability: knn_function_v4
283304

284-
testKnnInStatsNonPushable
285-
required_capability: knn_function_v3
286-
287-
from colors
288-
| where length(color) < 10
289-
| stats c = count(*) where knn(rgb_vector, [128,128,255], 140)
305+
from colors metadata _score
306+
| where knn(rgb_vector, [128,128,0])
307+
| sort _score desc, color asc
308+
| limit 15
309+
| stats c = count(*)
290310
;
291311

292-
c: long
293-
50
312+
c:long
313+
15
294314
;
295315

296-
testKnnInStatsWithGrouping
297-
required_capability: knn_function_v3
298-
required_capability: full_text_functions_in_stats_where
316+
testKnnWithRerank
317+
required_capability: knn_function_v4
318+
required_capability: rerank
299319

300-
from colors
301-
| where length(color) < 10
302-
| stats c = count(*) where knn(rgb_vector, [128,128,255], 140) by primary
320+
from colors metadata _score
321+
| where knn(rgb_vector, [100,120,0])
322+
| sort _score desc, color asc
323+
| limit 10
324+
| rerank rerank_score = "deepest blue" ON color WITH { "inference_id" : "test_reranker" }
325+
| sort rerank_score desc, color asc
326+
| keep color
303327
;
304328

305-
c: long | primary: boolean
306-
41 | false
307-
9 | true
329+
color:text
330+
gray
331+
peru
332+
brown
333+
green
334+
olive
335+
maroon
336+
sienna
337+
chocolate
338+
firebrick
339+
golden rod
308340
;

0 commit comments

Comments
 (0)