Skip to content

Commit 82779ec

Browse files
committed
Translate to exact NN when not pushable
1 parent ee3c806 commit 82779ec

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,18 +383,22 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua
383383
ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
384384
int i = 0;
385385
for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
386-
shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
386+
shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher());
387387
}
388388
return new LuceneQueryExpressionEvaluator.Factory(shardConfigs);
389389
}
390390

391+
protected QueryBuilder evaluatorQueryBuilder() {
392+
return queryBuilder();
393+
}
394+
391395
@Override
392396
public ScoreOperator.ExpressionScorer.Factory toScorer(ToScorer toScorer) {
393397
List<EsPhysicalOperationProviders.ShardContext> shardContexts = toScorer.shardContexts();
394398
ShardConfig[] shardConfigs = new ShardConfig[shardContexts.size()];
395399
int i = 0;
396400
for (EsPhysicalOperationProviders.ShardContext shardContext : shardContexts) {
397-
shardConfigs[i++] = new ShardConfig(shardContext.toQuery(queryBuilder()), shardContext.searcher());
401+
shardConfigs[i++] = new ShardConfig(shardContext.toQuery(evaluatorQueryBuilder()), shardContext.searcher());
398402
}
399403
return new LuceneQueryScoreEvaluator.Factory(shardConfigs);
400404
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.elasticsearch.common.io.stream.StreamInput;
1414
import org.elasticsearch.common.io.stream.StreamOutput;
1515
import org.elasticsearch.index.query.QueryBuilder;
16+
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
17+
import org.elasticsearch.search.vectors.VectorData;
1618
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
1719
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
1820
import org.elasticsearch.xpack.esql.capabilities.TranslationAware;
@@ -275,11 +277,7 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
275277

276278
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
277279
String fieldName = getNameFromFieldAttribute(fieldAttribute);
278-
List<Number> queryFolded = queryAsObject();
279-
float[] queryAsFloats = new float[queryFolded.size()];
280-
for (int i = 0; i < queryFolded.size(); i++) {
281-
queryAsFloats[i] = queryFolded.get(i).floatValue();
282-
}
280+
float[] queryAsFloats = queryAsFloats();
283281
int kValue = getKIntValue();
284282

285283
Map<String, Object> opts = queryOptions();
@@ -289,8 +287,8 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
289287
for (Expression filterExpression : filterExpressions()) {
290288
if (filterExpression instanceof TranslationAware translationAware) {
291289
// We can only translate filter expressions that are translatable. In case any is not translatable,
292-
// Knn won't be pushed down as it will not be translatable so it's safe not to translate all filters and check them
293-
// when creating an evaluator for the non-pushed down query
290+
// Knn won't be pushed down so it's safe not to translate all filters and check them when creating an evaluator
291+
// for the non-pushed down query
294292
if (translationAware.translatable(pushdownPredicates) == Translatable.YES) {
295293
filterQueries.add(handler.asQuery(pushdownPredicates, filterExpression).toQueryBuilder());
296294
}
@@ -300,6 +298,15 @@ protected Query translate(LucenePushdownPredicates pushdownPredicates, Translato
300298
return new KnnQuery(source(), fieldName, queryAsFloats, opts, filterQueries);
301299
}
302300

301+
private float[] queryAsFloats() {
302+
List<Number> queryFolded = queryAsObject();
303+
float[] queryAsFloats = new float[queryFolded.size()];
304+
for (int i = 0; i < queryFolded.size(); i++) {
305+
queryAsFloats[i] = queryFolded.get(i).floatValue();
306+
}
307+
return queryAsFloats;
308+
}
309+
303310
public Expression withFilters(List<Expression> filterExpressions) {
304311
return new Knn(source(), field(), query(), k(), options(), queryBuilder(), filterExpressions);
305312
}
@@ -312,6 +319,16 @@ private Map<String, Object> queryOptions() throws InvalidArgumentException {
312319
return options;
313320
}
314321

322+
protected QueryBuilder evaluatorQueryBuilder() {
323+
// Either we couldn't push down due to non-pushable filters, or becauses it's part of a disjuncion. Use exact query.
324+
var fieldAttribute = Match.fieldAsFieldAttribute(field());
325+
Check.notNull(fieldAttribute, "Knn must have a field attribute as the first argument");
326+
String fieldName = getNameFromFieldAttribute(fieldAttribute);
327+
Map<String, Object> opts = queryOptions();
328+
329+
return new ExactKnnQueryBuilder(VectorData.fromFloats(queryAsFloats()), fieldName, (Float) opts.get(VECTOR_SIMILARITY_FIELD));
330+
}
331+
315332
@Override
316333
public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
317334
return (plan, failures) -> {

0 commit comments

Comments
 (0)