Skip to content

Commit 2641cb3

Browse files
Handle KNN function
1 parent fdf54dc commit 2641cb3

File tree

7 files changed

+135
-18
lines changed

7 files changed

+135
-18
lines changed

x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ public boolean foldable() {
8585
return false;
8686
}
8787

88+
/**
89+
* Whether the expression can be evaluated partially, for example we might be able to fold and expression in an argument of the function but not remove the function completely
90+
*/
91+
public boolean partiallyFoldable() {
92+
return false;
93+
}
94+
8895
/**
8996
* Evaluate this expression statically to a constant. It is an error to call
9097
* this if {@link #foldable} returns false.
@@ -95,6 +102,10 @@ public Object fold(FoldContext ctx) {
95102
throw new QlIllegalArgumentException("Should not fold expression");
96103
}
97104

105+
public Expression partiallyFold(FoldContext ctx) {
106+
throw new QlIllegalArgumentException("Should not fold expression");
107+
}
108+
98109
public abstract Nullability nullable();
99110

100111
// the references/inputs/leaves of the expression tree

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/FunctionUtils.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ public static Integer limitValue(Expression limitField, String sourceText) {
3232
*/
3333
public static Expression.TypeResolution resolveTypeLimit(Expression limitField, String sourceText) {
3434
if (limitField == null) {
35-
return new Expression.TypeResolution(format(null, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField));
35+
return new Expression.TypeResolution(
36+
format(null, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField)
37+
);
3638
}
3739
if (limitField instanceof Literal literal) {
3840
if (literal.value() == null) {
@@ -47,6 +49,7 @@ public static Expression.TypeResolution resolveTypeLimit(Expression limitField,
4749
}
4850
return Expression.TypeResolution.TYPE_RESOLVED;
4951
}
52+
5053
public static void postOptimizationVerificationLimit(Failures failures, Expression limitField, String sourceText) {
5154
if (limitField == null) {
5255
failures.add(fail(limitField, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField));
@@ -62,6 +65,7 @@ public static void postOptimizationVerificationLimit(Failures failures, Expressi
6265
failures.add(fail(limitField, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField));
6366
}
6467
}
68+
6569
public static Expression.TypeResolution resolveTypeQuery(Expression queryField, String sourceText) {
6670
if (queryField == null) {
6771
return new Expression.TypeResolution(format(null, "Query must be a valid string in [{}], found [{}]", sourceText, queryField));
@@ -83,16 +87,15 @@ public static void postOptimizationVerificationQuery(Failures failures, Expressi
8387
if (queryField instanceof Literal literal) {
8488
String value = BytesRefs.toString(literal.value());
8589
if (value == null) {
86-
failures.add(
87-
fail(queryField, "Invalid query value in [{}], found [{}]", sourceText, value)
88-
);
90+
failures.add(fail(queryField, "Invalid query value in [{}], found [{}]", sourceText, value));
8991
}
9092
} else {
9193
// it is expected that the expression is a literal after folding
9294
// we fail if it is not a literal
9395
failures.add(fail(queryField, "Query must be a valid string in [{}], found [{}]", sourceText, queryField));
9496
}
9597
}
98+
9699
public static Object queryAsObject(Expression queryField, String sourceText) {
97100
if (queryField instanceof Literal literal) {
98101
return BytesRefs.toString(literal.value());

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ private TypeResolution resolveField() {
312312
}
313313

314314
private TypeResolution resolveQuery() {
315-
TypeResolution result = isType(
315+
TypeResolution result = isType(
316316
query(),
317317
QUERY_DATA_TYPES::contains,
318318
sourceText(),

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/Knn.java

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.vector;
99

10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
1012
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1113
import org.elasticsearch.common.io.stream.StreamInput;
1214
import org.elasticsearch.common.io.stream.StreamOutput;
1315
import org.elasticsearch.index.query.QueryBuilder;
16+
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
1417
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
1518
import org.elasticsearch.xpack.esql.common.Failures;
1619
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
1720
import org.elasticsearch.xpack.esql.core.expression.Expression;
1821
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
22+
import org.elasticsearch.xpack.esql.core.expression.Literal;
1923
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
2024
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
2125
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
@@ -46,6 +50,7 @@
4650
import java.util.function.BiConsumer;
4751

4852
import static java.util.Map.entry;
53+
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
4954
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
5055
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
5156
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
@@ -56,13 +61,14 @@
5661
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
5762
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
5863
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
59-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
6064
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
6165
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
6266
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
6367
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
68+
import static org.elasticsearch.xpack.esql.expression.function.FunctionUtils.resolveTypeQuery;
6469

6570
public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware {
71+
private final Logger log = LogManager.getLogger(getClass());
6672

6773
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);
6874

@@ -189,9 +195,16 @@ private TypeResolution resolveField() {
189195
}
190196

191197
private TypeResolution resolveQuery() {
192-
return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and(
193-
isNotNullAndFoldable(query(), sourceText(), SECOND)
194-
);
198+
TypeResolution result = isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector")
199+
.and(isNotNull(query(), sourceText(), SECOND));
200+
if (result.unresolved()) {
201+
return result;
202+
}
203+
result = resolveTypeQuery(query(), sourceText());
204+
if (result.equals(TypeResolution.TYPE_RESOLVED) == false) {
205+
return result;
206+
}
207+
return TypeResolution.TYPE_RESOLVED;
195208
}
196209

197210
private TypeResolution resolveK() {
@@ -235,19 +248,59 @@ private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
235248
return matchOptions;
236249
}
237250

251+
@Override
252+
public boolean partiallyFoldable() {
253+
return true;
254+
}
255+
256+
@Override
257+
public Expression partiallyFold(FoldContext ctx) {
258+
if (k instanceof Literal) {
259+
// already folded, return self
260+
return this;
261+
}
262+
Object foldedK = k.fold(ctx);
263+
if (foldedK instanceof Number == false) {
264+
throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
265+
}
266+
List<Expression> newChildren = new ArrayList<>(this.children());
267+
newChildren.set(2, new Literal(source(), foldedK, INTEGER));
268+
Expression replaced = this.replaceChildren(newChildren);
269+
log.error("Partially folded knn function [{}] with k value [{}]", replaced, foldedK);
270+
return replaced;
271+
}
272+
273+
@Override
274+
public List<Number> queryAsObject() {
275+
// we need to check that we got a list and every element in the list is a number
276+
Expression query = query();
277+
if (query instanceof Literal literal) {
278+
@SuppressWarnings("unchecked")
279+
List<Number> result = ((List<Number>) literal.value());
280+
return result;
281+
}
282+
throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query));
283+
}
284+
285+
int getKIntValue() {
286+
if (k() instanceof Literal literal) {
287+
return (int) (Number) literal.value();
288+
}
289+
throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
290+
}
291+
238292
@Override
239293
protected Query translate(TranslatorHandler handler) {
240294
var fieldAttribute = Match.fieldAsFieldAttribute(field());
241295

242296
Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
243297
String fieldName = getNameFromFieldAttribute(fieldAttribute);
244-
@SuppressWarnings("unchecked")
245-
List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */);
298+
List<Number> queryFolded = queryAsObject();
246299
float[] queryAsFloats = new float[queryFolded.size()];
247300
for (int i = 0; i < queryFolded.size(); i++) {
248301
queryAsFloats[i] = queryFolded.get(i).floatValue();
249302
}
250-
int kValue = ((Number) k().fold(FoldContext.small())).intValue();
303+
int kValue = getKIntValue();
251304

252305
Map<String, Object> opts = queryOptions();
253306
opts.put(K_FIELD.getPreferredName(), kValue);
@@ -322,12 +375,13 @@ public boolean equals(Object o) {
322375
Knn knn = (Knn) o;
323376
return Objects.equals(field(), knn.field())
324377
&& Objects.equals(query(), knn.query())
325-
&& Objects.equals(queryBuilder(), knn.queryBuilder());
378+
&& Objects.equals(queryBuilder(), knn.queryBuilder())
379+
&& Objects.equals(k(), knn.k());
326380
}
327381

328382
@Override
329383
public int hashCode() {
330-
return Objects.hash(field(), query(), queryBuilder());
384+
return Objects.hash(field(), query(), queryBuilder(), k());
331385
}
332386

333387
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ public ConstantFolding() {
1919

2020
@Override
2121
public Expression rule(Expression e, LogicalOptimizerContext ctx) {
22-
return e.foldable() ? Literal.of(ctx.foldCtx(), e) : e;
22+
if (e.foldable()) {
23+
return Literal.of(ctx.foldCtx(), e);
24+
} else if (e.partiallyFoldable()) {
25+
return e.partiallyFold(ctx.foldCtx());
26+
}
27+
return e;
2328
}
2429
}

x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/rule/RuleExecutor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.apache.logging.log4j.LogManager;
1010
import org.apache.logging.log4j.Logger;
11+
import org.apache.logging.log4j.core.config.Configurator;
1112
import org.elasticsearch.core.TimeValue;
1213
import org.elasticsearch.xpack.ql.tree.Node;
1314
import org.elasticsearch.xpack.ql.tree.NodeUtils;
@@ -137,6 +138,7 @@ protected final TreeType execute(TreeType plan) {
137138
}
138139

139140
protected final ExecutionInfo executeWithInfo(TreeType plan) {
141+
Configurator.setLevel(log.getName(), org.apache.logging.log4j.Level.TRACE);
140142
TreeType currentPlan = plan;
141143

142144
long totalDuration = 0;

x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/230_folding.yml

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,23 @@ setup:
2424
type: long
2525
name:
2626
type: keyword
27+
image_vector:
28+
type: dense_vector
29+
dims: 3
30+
index: true
31+
similarity: l2_norm
32+
2733
- do:
2834
bulk:
2935
index: employees
3036
refresh: true
3137
body:
3238
- { "index": { } }
33-
- { "hire_date": "2020-01-01", "salary_change": 100.5, "salary": 50000, "salary_change_long": 100, "name": "Alice Smith" }
39+
- { "hire_date": "2020-01-01", "salary_change": 100.5, "salary": 50000, "salary_change_long": 100, "name": "Alice Smith", "image_vector": [ 0.1, 0.2, 0.3 ] }
3440
- { "index": { } }
35-
- { "hire_date": "2021-01-01", "salary_change": 200.5, "salary": 60000, "salary_change_long": 200, "name": "Bob Johnson" }
41+
- { "hire_date": "2021-01-01", "salary_change": 200.5, "salary": 60000, "salary_change_long": 200, "name": "Bob Johnson", "image_vector": [ 0.4, 0.5, 0.6 ] }
3642
- { "index": { } }
37-
- { "hire_date": "2019-01-01", "salary_change": 50.5, "salary": 40000, "salary_change_long": 50, "name": "Charlie Smith" }
43+
- { "hire_date": "2019-01-01", "salary_change": 50.5, "salary": 40000, "salary_change_long": 50, "name": "Charlie Smith", "image_vector": [ 0.7, 0.8, 0.9 ] }
3844

3945
---
4046
Top function with constant folding:
@@ -279,3 +285,39 @@ Foldable query using QSTR on name but with non-foldable expression:
279285
- match: { error.type: "verification_exception" }
280286
- contains: { error.reason: "Query must be a valid string in [QSTR(CONCAT(name, \"Bob\"))], found [CONCAT(name, \"Bob\")]" }
281287

288+
---
289+
290+
Foldable query using KNN on image_vector:
291+
- do:
292+
esql.query:
293+
body:
294+
query: |
295+
FROM employees
296+
| WHERE KNN(image_vector, [0.4, 0.5, 0.9], 1 + 1)
297+
| KEEP hire_date, salary, salary_change, salary_change_long, name, image_vector
298+
| LIMIT 5
299+
- match: { columns.0.name: "hire_date" }
300+
- match: { columns.1.name: "salary" }
301+
- match: { columns.2.name: "salary_change" }
302+
- match: { columns.3.name: "salary_change_long" }
303+
- match: { columns.4.name: "name" }
304+
- match: { columns.5.name: "image_vector" }
305+
- length: { values: 2 }
306+
# The closest vectors to [0.4, 0.5, 0.6] are Bob Johnson and Charlie Smith
307+
- match: { values.0.4: "Bob Johnson" }
308+
- match: { values.1.4: "Charlie Smith" }
309+
310+
---
311+
312+
Foldable query using KNN on image_vector but with non-foldable expression:
313+
- do:
314+
catch: bad_request
315+
esql.query:
316+
body:
317+
query: |
318+
FROM employees
319+
| WHERE KNN(image_vector, [0.4, 0.5, 0.9], 1+salary)
320+
| KEEP hire_date, salary, salary_change, salary_change_long, name, image_vector
321+
| LIMIT 5
322+
- match: { error.type: "verification_exception" }
323+
- contains: { error.reason: "third argument of [KNN(image_vector, [0.4, 0.5, 0.9], 1+salary)] must be a constant, received [1+salary]" }

0 commit comments

Comments
 (0)