Skip to content

Commit b67e121

Browse files
committed
Analyzer checks foldables
1 parent ffc52aa commit b67e121

File tree

1 file changed

+16
-7
lines changed
  • x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis

1 file changed

+16
-7
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.esql.analysis;
99

10+
import org.apache.lucene.util.BytesRef;
1011
import org.elasticsearch.common.logging.HeaderWarning;
1112
import org.elasticsearch.common.logging.LoggerMessageFormat;
1213
import org.elasticsearch.common.lucene.BytesRefs;
@@ -139,6 +140,7 @@
139140
import java.util.Comparator;
140141
import java.util.HashMap;
141142
import java.util.HashSet;
143+
import java.util.HexFormat;
142144
import java.util.LinkedHashMap;
143145
import java.util.List;
144146
import java.util.Map;
@@ -1672,23 +1674,31 @@ private static Expression processVectorFunction(org.elasticsearch.xpack.esql.cor
16721674
int vectorArgsCount = ((VectorFunction)vectorFunction).vectorArgumentsCount();
16731675
for (int i = 0; i < args.size(); i++) {
16741676
Expression arg = args.get(i);
1675-
if (i < vectorArgsCount && arg.resolved() && arg.dataType().isNumeric()) {
1677+
if (i < vectorArgsCount && arg.resolved()) {
16761678
if (arg.foldable()) {
1677-
Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
1678-
if (folded instanceof List) {
1679+
Object folded = arg.fold(FoldContext.small());
1680+
List<Float> floatVector = null;
1681+
if (folded instanceof List && arg.dataType().isNumeric()) {
16791682
// Convert to floats so blocks are created accordingly
1680-
List<Float> floatVector;
16811683
if (arg.dataType() == FLOAT) {
16821684
floatVector = (List<Float>) folded;
16831685
} else {
16841686
floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
16851687
}
1688+
} else if (folded instanceof BytesRef hexString && arg.dataType() == KEYWORD) {
1689+
byte[] bytes = HexFormat.of().parseHex(hexString.utf8ToString());
1690+
floatVector = new ArrayList<>();
1691+
for (byte value : bytes) {
1692+
floatVector.add((float) value);
1693+
}
1694+
}
1695+
if (floatVector != null) {
16861696
Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
16871697
newArgs.add(denseVector);
16881698
continue;
16891699
}
1690-
} else {
1691-
// add casting function
1700+
} else if ((arg instanceof ToDenseVector == false) && (arg.dataType().isNumeric() || arg.dataType() == KEYWORD)) {
1701+
// add casting function if it's not already there
16921702
newArgs.add(new ToDenseVector(arg.source(), arg));
16931703
continue;
16941704
}
@@ -1698,7 +1708,6 @@ private static Expression processVectorFunction(org.elasticsearch.xpack.esql.cor
16981708

16991709
return vectorFunction.replaceChildren(newArgs);
17001710
}
1701-
17021711
}
17031712

17041713
/**

0 commit comments

Comments
 (0)