Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b7a99fc
Add prefilters to Knn
carlosdelest Jul 9, 2025
61a1e4b
Add logical plan optimizer rule to add prefilters
carlosdelest Jul 9, 2025
fc0e54b
Adds filters to KnnVectorQueryBuilder when translating
carlosdelest Jul 9, 2025
6a52b63
Update KNN capability name
carlosdelest Jul 9, 2025
a73973f
Add tests
carlosdelest Jul 9, 2025
ef12ee5
Avoid infinite loop with multiple knn expressions, add tests
carlosdelest Jul 9, 2025
616b725
Add test for not pushing down knn when there are non-pushable prefilters
carlosdelest Jul 9, 2025
3ae0e9b
Execute premapper after logical optimisation, to allow query builders…
carlosdelest Jul 10, 2025
50d6e19
Fix generating QueryBuilder prefilters for knn when they cannot be pu…
carlosdelest Jul 10, 2025
32715ce
Update CSV tests now that knn uses prefilters
carlosdelest Jul 10, 2025
7ebd1d9
Merge remote-tracking branch 'origin/main' into esql-knn-prefilter
carlosdelest Jul 10, 2025
9455f77
Merge branch 'main' into esql-knn-prefilter
carlosdelest Jul 10, 2025
62b0d1c
Remove TransportVersion as knn is not released yet
carlosdelest Jul 11, 2025
82fd2cc
Merge remote-tracking branch 'origin/main' into esql-knn-prefilter
carlosdelest Jul 11, 2025
bfdd227
Merge remote-tracking branch 'carlosdelest/esql-knn-prefilter' into e…
carlosdelest Jul 11, 2025
b347bba
Add queryBuilder and filterExpresions to NodeInfo
carlosdelest Jul 11, 2025
1083764
Remove unnecessary method
carlosdelest Jul 11, 2025
6e4c1a9
Fix test for node info
carlosdelest Jul 11, 2025
91d4a0a
Fix typo
carlosdelest Jul 11, 2025
a99e75f
Add multiple filters test
carlosdelest Jul 11, 2025
4c6bb4d
Add negations tests
carlosdelest Jul 11, 2025
47a6d99
[CI] Auto commit changes from spotless
Jul 11, 2025
ac7c3bc
knn functions won't appear as prefilters for other knn functions
carlosdelest Jul 14, 2025
32ada3d
Merge remote-tracking branch 'carlosdelest/esql-knn-prefilter' into e…
carlosdelest Jul 14, 2025
939705c
Merge branch 'main' into esql-knn-prefilter
carlosdelest Jul 14, 2025
dd29fec
Fix test
carlosdelest Jul 14, 2025
b03f2a3
Fix test
carlosdelest Jul 14, 2025
20fe8fe
Merge branch 'main' into esql-knn-prefilter
carlosdelest Jul 15, 2025
89c86e4
Merge remote-tracking branch 'origin/main' into esql-knn-prefilter
carlosdelest Jul 16, 2025
8a550a1
Fix merge
carlosdelest Jul 16, 2025
ea2fddd
Merge remote-tracking branch 'carlosdelest/esql-knn-prefilter' into e…
carlosdelest Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# top-n query at the shard level

knnSearch
required_capability: knn_function_v2
required_capability: knn_function_v3

// tag::knn-function[]
from colors metadata _score
Expand All @@ -30,7 +30,7 @@ chartreuse | [127.0, 255.0, 0.0]
;

knnSearchWithSimilarityOption
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| where knn(rgb_vector, [255,192,203], 140, {"similarity": 40})
Expand All @@ -46,14 +46,13 @@ wheat | [245.0, 222.0, 179.0]
;

knnHybridSearch
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 140)
| where match(color, "blue") or knn(rgb_vector, [65,105,225], 10)
| where primary == true
| sort _score desc, color asc
| keep color, rgb_vector
| limit 10
;

color:text | rgb_vector:dense_vector
Expand All @@ -68,21 +67,45 @@ red | [255.0, 0.0, 0.0]
yellow | [255.0, 255.0, 0.0]
;

knnWithMultipleFunctions
required_capability: knn_function_v2
knnWithPrefilter
required_capability: knn_function_v3

from colors metadata _score
| where knn(rgb_vector, [128,128,0], 140) and match(color, "olive")
| where knn(rgb_vector, [128,128,0], 10) and (match(color, "olive") or match(color, "green"))
| sort _score desc, color asc
| keep color, rgb_vector
;

color:text | rgb_vector:dense_vector
olive | [128.0, 128.0, 0.0]
green | [0.0, 128.0, 0.0]
;

knnWithNegatedPrefilter
required_capability: knn_function_v3

from colors metadata _score
| where knn(rgb_vector, [128,128,0], 10) and not (match(color, "olive") or match(color, "chocolate"))
| sort _score desc, color asc
| keep color, rgb_vector
| LIMIT 10
;

color:text | rgb_vector:dense_vector
sienna | [160.0, 82.0, 45.0]
peru | [205.0, 133.0, 63.0]
golden rod | [218.0, 165.0, 32.0]
brown | [165.0, 42.0, 42.0]
firebrick | [178.0, 34.0, 34.0]
chartreuse | [127.0, 255.0, 0.0]
gray | [128.0, 128.0, 128.0]
green | [0.0, 128.0, 0.0]
maroon | [128.0, 0.0, 0.0]
orange | [255.0, 165.0, 0.0]
;

knnAfterKeep
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| keep rgb_vector, color, _score
Expand All @@ -101,7 +124,7 @@ rgb_vector:dense_vector
;

knnAfterDrop
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| drop primary
Expand All @@ -120,7 +143,7 @@ lime | [0.0, 255.0, 0.0]
;

knnAfterEval
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| eval composed_name = locate(color, " ") > 0
Expand All @@ -139,14 +162,12 @@ golden rod | true
;

knnWithConjunction
required_capability: knn_function_v2
required_capability: knn_function_v3

# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
| where knn(rgb_vector, [255,255,238], 140) and hex_code like "#FFF*"
| where knn(rgb_vector, [255,255,238], 10) and hex_code like "#FFF*"
| sort _score desc, color asc
| keep color, hex_code, rgb_vector
| limit 10
;

color:text | hex_code:keyword | rgb_vector:dense_vector
Expand All @@ -160,11 +181,10 @@ yellow | #FFFF00 | [255.0, 255.0, 0.0]
;

knnWithDisjunctionAndFiltersConjunction
required_capability: knn_function_v2
required_capability: knn_function_v3

# TODO We need kNN prefiltering here so we get more candidates that pass the filter
from colors metadata _score
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 140)) and primary == true
| where (knn(rgb_vector, [0,255,255], 140) or knn(rgb_vector, [128, 0, 255], 10)) and primary == true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a prefilter means that we no longer need to retrieve all results: the conjunction is now pushed down as a prefilter, so we can use k as intended.

| keep color, rgb_vector, _score
| sort _score desc, color asc
| drop _score
Expand All @@ -183,8 +203,31 @@ red | [255.0, 0.0, 0.0]
yellow | [255.0, 255.0, 0.0]
;

knnWithNegationsAndFiltersConjunction
required_capability: knn_function_v3

from colors metadata _score
| where (knn(rgb_vector, [0,255,255], 140) and not(primary == true and match(color, "blue")))
| sort _score desc, color asc
| keep color, rgb_vector
| limit 10
;

color:text | rgb_vector:dense_vector
cyan | [0.0, 255.0, 255.0]
turquoise | [64.0, 224.0, 208.0]
aqua marine | [127.0, 255.0, 212.0]
teal | [0.0, 128.0, 128.0]
silver | [192.0, 192.0, 192.0]
gray | [128.0, 128.0, 128.0]
gainsboro | [220.0, 220.0, 220.0]
thistle | [216.0, 191.0, 216.0]
lavender | [230.0, 230.0, 250.0]
azure | [240.0, 255.0, 255.0]
;

knnWithNonPushableConjunction
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| eval composed_name = locate(color, " ") > 0
Expand All @@ -208,7 +251,7 @@ maroon | false
;

testKnnWithNonPushableDisjunctions
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| where knn(rgb_vector, [128,128,0], 140, {"similarity": 30}) or length(color) > 10
Expand All @@ -224,7 +267,7 @@ papaya whip
;

testKnnWithNonPushableDisjunctionsOnComplexExpressions
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors metadata _score
| where (knn(rgb_vector, [128,128,0], 140, {"similarity": 70}) and length(color) < 10) or (knn(rgb_vector, [128,0,128], 140, {"similarity": 60}) and primary == false)
Expand All @@ -239,7 +282,7 @@ indigo | false
;

testKnnInStatsNonPushable
required_capability: knn_function_v2
required_capability: knn_function_v3

from colors
| where length(color) < 10
Expand All @@ -251,7 +294,7 @@ c: long
;

testKnnInStatsWithGrouping
required_capability: knn_function_v2
required_capability: knn_function_v3
required_capability: full_text_functions_in_stats_where

from colors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,29 @@ public void testKnnNonPushedDown() {
}
}

/**
 * Checks that a condition AND-ed with {@code knn} is applied as a prefilter of the kNN search:
 * the query asks for k = 5 neighbours restricted to {@code id > 5} and expects exactly 5 rows.
 */
public void testKnnWithPrefilters() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);

// knn retrieves only k = 5 candidates, so `id > 5` has to be applied as a prefilter. If it were applied
// after the kNN search (post-filter), the 5 retrieved candidates would be filtered out and no rows returned.
// NOTE(review): this relies on the nearest documents having ids <= 5 - confirm against setup().
var query = String.format(Locale.ROOT, """
FROM test METADATA _score
| WHERE knn(vector, %s, 5) AND id > 5
| KEEP id, floats, _score, vector
| SORT _score DESC
| LIMIT 5
""", Arrays.toString(queryVector));

try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "floats", "_score", "vector"));
assertColumnTypes(resp.columns(), List.of("integer", "double", "double", "dense_vector"));

List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
// k = 5 with the prefilter applied: exactly 5 rows are expected.
// Presumably setup() indexes enough documents with id > 5 (numDocs is at least 15) - verify.
assertEquals(5, valuesList.size());
}
}

public void testKnnWithLookupJoin() {
float[] queryVector = new float[numDims];
Arrays.fill(queryVector, 1.0f);
Expand All @@ -136,7 +159,7 @@ public void testKnnWithLookupJoin() {

@Before
public void setup() throws IOException {
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled());
assumeTrue("Needs KNN support", EsqlCapabilities.Cap.KNN_FUNCTION_V3.isEnabled());

var indexName = "test";
var client = client().admin().indices();
Expand All @@ -163,7 +186,7 @@ public void setup() throws IOException {
var createRequest = client.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
assertAcked(createRequest);

numDocs = randomIntBetween(10, 20);
numDocs = randomIntBetween(15, 25);
numDims = randomIntBetween(3, 10);
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
float value = 0.0f;
Expand Down Expand Up @@ -202,6 +225,5 @@ private void createAndPopulateLookupIndex(IndicesAdminClient client, String look

var createRequest = client.prepareCreate(lookupIndexName).setMapping(mapping).setSettings(settingsBuilder.build());
assertAcked(createRequest);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,7 @@ public enum Cap {
/**
* Support knn function
*/
KNN_FUNCTION_V2(Build.current().isSnapshot()),
KNN_FUNCTION_V3(Build.current().isSnapshot()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nitpick - I wonder if we could create an alias that this could reference (so when we go to V4 we only need to change one variable when testing for test compatibility)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide an example of what it would look like?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like KNN_FUNCTION_CURRENT = EsqlCapabilities.Cap.KNN_FUNCTION_V3?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that works with CSV tests 🤔 . I'd say let's keep this as is, since it's an established pattern in ES|QL.

Maybe I'll give it a try for the next one and see what it looks like from a changes perspective 👍


/**
* Support for the LIKE operator with a list of wildcards.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
def(FirstOverTime.class, uni(FirstOverTime::new), "first_over_time"),
def(Score.class, uni(Score::new), Score.NAME),
def(Term.class, bi(Term::new), "term"),
def(Knn.class, Knn::new, "knn"),
def(Knn.class, quad(Knn::new), "knn"),
def(StGeohash.class, StGeohash::new, "st_geohash"),
def(StGeohashToLong.class, StGeohashToLong::new, "st_geohash_to_long"),
def(StGeohashToString.class, StGeohashToString::new, "st_geohash_to_string"),
Expand Down Expand Up @@ -1208,4 +1208,8 @@ private static <T extends Function> TernaryBuilder<T> tri(TernaryBuilder<T> func
return function;
}

/**
 * Identity helper that gives {@code function} the explicit type {@link QuaternaryBuilder}, so that a
 * constructor reference such as {@code Knn::new} can be passed to {@code def(...)}. It is the
 * {@code QuaternaryBuilder} counterpart of {@code tri}.
 *
 * @param function the builder to return
 * @return {@code function}, unchanged
 */
private static <T extends Function> QuaternaryBuilder<T> quad(QuaternaryBuilder<T> function) {
return function;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -166,20 +166,19 @@ public boolean equals(Object obj) {

@Override
public Translatable translatable(LucenePushdownPredicates pushdownPredicates) {
// In isolation, full text functions are pushable to source. We check if there are no disjunctions in Or conditions
return Translatable.YES;
}

@Override
public Query asQuery(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(handler);
return queryBuilder != null ? new TranslationAwareExpressionQuery(source(), queryBuilder) : translate(pushdownPredicates, handler);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added pushdownPredicates to the translate method, as we need to evaluate the prefilters with them in order to decide whether they can be pushed down.

}

public QueryBuilder queryBuilder() {
return queryBuilder;
}

protected abstract Query translate(TranslatorHandler handler);
protected abstract Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be easier to have two signatures for this, for a smaller file change count?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's more confusing. LucenePushdownPredicates are part of the translation process even if they were not used until now.


public abstract Expression replaceQueryBuilder(QueryBuilder queryBuilder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.KqlQuery;

Expand Down Expand Up @@ -93,7 +94,7 @@ protected NodeInfo<? extends Expression> info() {
}

@Override
protected Query translate(TranslatorHandler handler) {
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
return new KqlQuery(source(), Objects.toString(queryAsObject()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.MatchQuery;
Expand Down Expand Up @@ -423,7 +424,7 @@ public Object queryAsObject() {
}

@Override
protected Query translate(TranslatorHandler handler) {
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
var fieldAttribute = fieldAsFieldAttribute();
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
String fieldName = getNameFromFieldAttribute(fieldAttribute);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.MatchPhraseQuery;
Expand Down Expand Up @@ -278,7 +279,7 @@ public Object queryAsObject() {
}

@Override
protected Query translate(TranslatorHandler handler) {
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
var fieldAttribute = fieldAsFieldAttribute();
Check.notNull(fieldAttribute, "MatchPhrase must have a field attribute as the first argument");
String fieldName = getNameFromFieldAttribute(fieldAttribute);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;
import org.elasticsearch.xpack.esql.querydsl.query.MultiMatchQuery;
Expand Down Expand Up @@ -335,7 +336,7 @@ protected NodeInfo<? extends Expression> info() {
}

@Override
protected Query translate(TranslatorHandler handler) {
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
Map<String, Float> fieldsWithBoost = new HashMap<>();
for (Expression field : fields) {
var fieldAttribute = Match.fieldAsFieldAttribute(field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xpack.esql.expression.function.OptionalArgument;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.optimizer.rules.physical.local.LucenePushdownPredicates;
import org.elasticsearch.xpack.esql.planner.TranslatorHandler;

import java.io.IOException;
Expand Down Expand Up @@ -345,7 +346,7 @@ protected NodeInfo<? extends Expression> info() {
}

@Override
protected Query translate(TranslatorHandler handler) {
protected Query translate(LucenePushdownPredicates pushdownPredicates, TranslatorHandler handler) {
return new QueryStringQuery(source(), Objects.toString(queryAsObject()), Map.of(), queryStringOptions());
}

Expand Down
Loading
Loading