Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
0f080a5
Fix for top
julian-elastic Jul 9, 2025
8a24206
Fix for sample function
julian-elastic Jul 10, 2025
9dd5502
[CI] Auto commit changes from spotless
Jul 10, 2025
67dc79f
Handle Match function
julian-elastic Jul 10, 2025
923dc02
Handle MatchPhrase function
julian-elastic Jul 11, 2025
29cce29
Handle MultiMatch function
julian-elastic Jul 11, 2025
fdf54dc
Handle QueryString function
julian-elastic Jul 11, 2025
2641cb3
Handle KNN function
julian-elastic Jul 11, 2025
8e2c6b4
Merge remote-tracking branch 'origin/foldable' into foldable
julian-elastic Jul 12, 2025
15dfed3
Migrate tests from VerifierTests to 230_folding.yml
julian-elastic Jul 14, 2025
408c8fd
Remove partiallyFoldable code as it is not needed
julian-elastic Jul 14, 2025
f187139
Fix some of the failing UTs
julian-elastic Jul 14, 2025
d66fdfc
Merge branch 'main' into foldable
julian-elastic Jul 15, 2025
d80dd57
[CI] Auto commit changes from spotless
Jul 15, 2025
6f34981
Remove some debugging code
julian-elastic Jul 15, 2025
4247581
Fix failing UTs in old version
julian-elastic Jul 16, 2025
b2213e1
Merge branch 'main' into foldable
julian-elastic Jul 16, 2025
029f96a
Integrate with knn_function_v3
julian-elastic Jul 16, 2025
18e18a2
Fix UT fails
julian-elastic Jul 17, 2025
23359b5
Fix UT fail
julian-elastic Jul 17, 2025
4007272
Merge branch 'main' into foldable
julian-elastic Jul 17, 2025
9537344
Fix merge error
julian-elastic Jul 17, 2025
8e9d34b
Merge branch 'main' into foldable
julian-elastic Jul 18, 2025
20245b6
Fix merge error
julian-elastic Jul 18, 2025
b2e796c
[CI] Auto commit changes from spotless
Jul 18, 2025
cb676a7
Fix UT error
julian-elastic Jul 18, 2025
910c1d0
Merge branch 'main' into foldable
julian-elastic Jul 18, 2025
91f76d6
Remove isNotNullAndFoldable function completely
julian-elastic Jul 18, 2025
b3ee3a0
Update docs/changelog/130944.yaml
julian-elastic Jul 18, 2025
8dfdf1a
Fix UT failures related to trying to get the datatype on unresolved a…
julian-elastic Jul 18, 2025
4087f30
Merge branch 'main' into foldable
julian-elastic Jul 19, 2025
8e11323
Merge branch 'main' into foldable
julian-elastic Jul 21, 2025
89dfb4e
Update docs/changelog/130944.yaml
julian-elastic Jul 21, 2025
55b9246
Fix failing UTs
julian-elastic Jul 21, 2025
94909ec
Fix failing UT
julian-elastic Jul 21, 2025
2444673
Merge branch 'main' into foldable
julian-elastic Jul 21, 2025
1d730a5
Fix UT error
julian-elastic Jul 21, 2025
6714b8b
Update docs/changelog/130944.yaml
julian-elastic Jul 21, 2025
094aa80
Address code review feedback
julian-elastic Jul 23, 2025
50918a4
Fix checkstyle
julian-elastic Jul 23, 2025
5012a2e
Merge branch 'main' into foldable
julian-elastic Jul 23, 2025
09bbf06
Address code review comments
julian-elastic Jul 24, 2025
9fce42b
Merge branch 'main' into foldable
julian-elastic Jul 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ public boolean foldable() {
return false;
}

/**
* Whether the expression can be evaluated partially, for example we might be able to fold and expression in an argument of the function but not remove the function completely
*/
public boolean partiallyFoldable() {
return false;
}

/**
* Evaluate this expression statically to a constant. It is an error to call
* this if {@link #foldable} returns false.
Expand All @@ -95,6 +102,10 @@ public Object fold(FoldContext ctx) {
throw new QlIllegalArgumentException("Should not fold expression");
}

public Expression partiallyFold(FoldContext ctx) {
throw new QlIllegalArgumentException("Should not fold expression");
}

public abstract Nullability nullable();

// the references/inputs/leaves of the expression tree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ public static Integer limitValue(Expression limitField, String sourceText) {
*/
public static Expression.TypeResolution resolveTypeLimit(Expression limitField, String sourceText) {
if (limitField == null) {
return new Expression.TypeResolution(format(null, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField));
return new Expression.TypeResolution(
format(null, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField)
);
}
if (limitField instanceof Literal literal) {
if (literal.value() == null) {
Expand All @@ -47,6 +49,7 @@ public static Expression.TypeResolution resolveTypeLimit(Expression limitField,
}
return Expression.TypeResolution.TYPE_RESOLVED;
}

public static void postOptimizationVerificationLimit(Failures failures, Expression limitField, String sourceText) {
if (limitField == null) {
failures.add(fail(limitField, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField));
Expand All @@ -62,6 +65,7 @@ public static void postOptimizationVerificationLimit(Failures failures, Expressi
failures.add(fail(limitField, "Limit must be a constant integer in [{}], found [{}]", sourceText, limitField));
}
}

public static Expression.TypeResolution resolveTypeQuery(Expression queryField, String sourceText) {
if (queryField == null) {
return new Expression.TypeResolution(format(null, "Query must be a valid string in [{}], found [{}]", sourceText, queryField));
Expand All @@ -83,16 +87,15 @@ public static void postOptimizationVerificationQuery(Failures failures, Expressi
if (queryField instanceof Literal literal) {
String value = BytesRefs.toString(literal.value());
if (value == null) {
failures.add(
fail(queryField, "Invalid query value in [{}], found [{}]", sourceText, value)
);
failures.add(fail(queryField, "Invalid query value in [{}], found [{}]", sourceText, value));
}
} else {
// it is expected that the expression is a literal after folding
// we fail if it is not a literal
failures.add(fail(queryField, "Query must be a valid string in [{}], found [{}]", sourceText, queryField));
}
}

public static Object queryAsObject(Expression queryField, String sourceText) {
if (queryField instanceof Literal literal) {
return BytesRefs.toString(literal.value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ private TypeResolution resolveField() {
}

private TypeResolution resolveQuery() {
TypeResolution result = isType(
TypeResolution result = isType(
query(),
QUERY_DATA_TYPES::contains,
sourceText(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@

package org.elasticsearch.xpack.esql.expression.function.vector;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.InvalidArgumentException;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.Literal;
import org.elasticsearch.xpack.esql.core.expression.MapExpression;
import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
import org.elasticsearch.xpack.esql.core.querydsl.query.Query;
Expand Down Expand Up @@ -46,6 +50,7 @@
import java.util.function.BiConsumer;

import static java.util.Map.entry;
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.K_FIELD;
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.NUM_CANDS_FIELD;
Expand All @@ -56,13 +61,14 @@
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isMapExpression;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
import static org.elasticsearch.xpack.esql.core.type.DataType.FLOAT;
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
import static org.elasticsearch.xpack.esql.expression.function.FunctionUtils.resolveTypeQuery;

public class Knn extends FullTextFunction implements OptionalArgument, VectorFunction, PostAnalysisPlanVerificationAware {
private final Logger log = LogManager.getLogger(getClass());

public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Knn", Knn::readFrom);

Expand Down Expand Up @@ -189,9 +195,16 @@ private TypeResolution resolveField() {
}

private TypeResolution resolveQuery() {
return isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector").and(
isNotNullAndFoldable(query(), sourceText(), SECOND)
);
TypeResolution result = isType(query(), dt -> dt == DENSE_VECTOR, sourceText(), TypeResolutions.ParamOrdinal.SECOND, "dense_vector")
.and(isNotNull(query(), sourceText(), SECOND));
if (result.unresolved()) {
return result;
}
result = resolveTypeQuery(query(), sourceText());
if (result.equals(TypeResolution.TYPE_RESOLVED) == false) {
return result;
}
return TypeResolution.TYPE_RESOLVED;
}

private TypeResolution resolveK() {
Expand Down Expand Up @@ -235,19 +248,59 @@ private Map<String, Object> knnQueryOptions() throws InvalidArgumentException {
return matchOptions;
}

@Override
public boolean partiallyFoldable() {
return true;
}

@Override
public Expression partiallyFold(FoldContext ctx) {
if (k instanceof Literal) {
// already folded, return self
return this;
}
Object foldedK = k.fold(ctx);
if (foldedK instanceof Number == false) {
throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
}
List<Expression> newChildren = new ArrayList<>(this.children());
newChildren.set(2, new Literal(source(), foldedK, INTEGER));
Expression replaced = this.replaceChildren(newChildren);
log.error("Partially folded knn function [{}] with k value [{}]", replaced, foldedK);
return replaced;
}

@Override
public List<Number> queryAsObject() {
// we need to check that we got a list and every element in the list is a number
Expression query = query();
if (query instanceof Literal literal) {
@SuppressWarnings("unchecked")
List<Number> result = ((List<Number>) literal.value());
return result;
}
throw new EsqlIllegalArgumentException(format(null, "Query value must be a list of numbers in [{}], found [{}]", source(), query));
}

int getKIntValue() {
if (k() instanceof Literal literal) {
return (int) (Number) literal.value();
}
throw new EsqlIllegalArgumentException(format(null, "K value must be a constant integer in [{}], found [{}]", source(), k()));
}

@Override
protected Query translate(TranslatorHandler handler) {
var fieldAttribute = Match.fieldAsFieldAttribute(field());

Check.notNull(fieldAttribute, "Match must have a field attribute as the first argument");
String fieldName = getNameFromFieldAttribute(fieldAttribute);
@SuppressWarnings("unchecked")
List<Number> queryFolded = (List<Number>) query().fold(FoldContext.small() /* TODO remove me */);
List<Number> queryFolded = queryAsObject();
float[] queryAsFloats = new float[queryFolded.size()];
for (int i = 0; i < queryFolded.size(); i++) {
queryAsFloats[i] = queryFolded.get(i).floatValue();
}
int kValue = ((Number) k().fold(FoldContext.small())).intValue();
int kValue = getKIntValue();

Map<String, Object> opts = queryOptions();
opts.put(K_FIELD.getPreferredName(), kValue);
Expand Down Expand Up @@ -322,12 +375,13 @@ public boolean equals(Object o) {
Knn knn = (Knn) o;
return Objects.equals(field(), knn.field())
&& Objects.equals(query(), knn.query())
&& Objects.equals(queryBuilder(), knn.queryBuilder());
&& Objects.equals(queryBuilder(), knn.queryBuilder())
&& Objects.equals(k(), knn.k());
}

@Override
public int hashCode() {
return Objects.hash(field(), query(), queryBuilder());
return Objects.hash(field(), query(), queryBuilder(), k());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ public ConstantFolding() {

@Override
public Expression rule(Expression e, LogicalOptimizerContext ctx) {
return e.foldable() ? Literal.of(ctx.foldCtx(), e) : e;
if (e.foldable()) {
return Literal.of(ctx.foldCtx(), e);
} else if (e.partiallyFoldable()) {
return e.partiallyFold(ctx.foldCtx());
}
return e;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.config.Configurator;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.ql.tree.Node;
import org.elasticsearch.xpack.ql.tree.NodeUtils;
Expand Down Expand Up @@ -137,6 +138,7 @@ protected final TreeType execute(TreeType plan) {
}

protected final ExecutionInfo executeWithInfo(TreeType plan) {
Configurator.setLevel(log.getName(), org.apache.logging.log4j.Level.TRACE);
TreeType currentPlan = plan;

long totalDuration = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,23 @@ setup:
type: long
name:
type: keyword
image_vector:
type: dense_vector
dims: 3
index: true
similarity: l2_norm

- do:
bulk:
index: employees
refresh: true
body:
- { "index": { } }
- { "hire_date": "2020-01-01", "salary_change": 100.5, "salary": 50000, "salary_change_long": 100, "name": "Alice Smith" }
- { "hire_date": "2020-01-01", "salary_change": 100.5, "salary": 50000, "salary_change_long": 100, "name": "Alice Smith", "image_vector": [ 0.1, 0.2, 0.3 ] }
- { "index": { } }
- { "hire_date": "2021-01-01", "salary_change": 200.5, "salary": 60000, "salary_change_long": 200, "name": "Bob Johnson" }
- { "hire_date": "2021-01-01", "salary_change": 200.5, "salary": 60000, "salary_change_long": 200, "name": "Bob Johnson", "image_vector": [ 0.4, 0.5, 0.6 ] }
- { "index": { } }
- { "hire_date": "2019-01-01", "salary_change": 50.5, "salary": 40000, "salary_change_long": 50, "name": "Charlie Smith" }
- { "hire_date": "2019-01-01", "salary_change": 50.5, "salary": 40000, "salary_change_long": 50, "name": "Charlie Smith", "image_vector": [ 0.7, 0.8, 0.9 ] }

---
Top function with constant folding:
Expand Down Expand Up @@ -279,3 +285,39 @@ Foldable query using QSTR on name but with non-foldable expression:
- match: { error.type: "verification_exception" }
- contains: { error.reason: "Query must be a valid string in [QSTR(CONCAT(name, \"Bob\"))], found [CONCAT(name, \"Bob\")]" }

---

Foldable query using KNN on image_vector:
- do:
esql.query:
body:
query: |
FROM employees
| WHERE KNN(image_vector, [0.4, 0.5, 0.9], 1 + 1)
| KEEP hire_date, salary, salary_change, salary_change_long, name, image_vector
| LIMIT 5
- match: { columns.0.name: "hire_date" }
- match: { columns.1.name: "salary" }
- match: { columns.2.name: "salary_change" }
- match: { columns.3.name: "salary_change_long" }
- match: { columns.4.name: "name" }
- match: { columns.5.name: "image_vector" }
- length: { values: 2 }
# The closest vectors to [0.4, 0.5, 0.6] are Bob Johnson and Charlie Smith
- match: { values.0.4: "Bob Johnson" }
- match: { values.1.4: "Charlie Smith" }

---

Foldable query using KNN on image_vector but with non-foldable expression:
- do:
catch: bad_request
esql.query:
body:
query: |
FROM employees
| WHERE KNN(image_vector, [0.4, 0.5, 0.9], 1+salary)
| KEEP hire_date, salary, salary_change, salary_change_long, name, image_vector
| LIMIT 5
- match: { error.type: "verification_exception" }
- contains: { error.reason: "third argument of [KNN(image_vector, [0.4, 0.5, 0.9], 1+salary)] must be a constant, received [1+salary]" }