Skip to content

Commit ffc52aa

Browse files
committed
Add ToDenseVectorFromStringEvaluator
1 parent 8e7247a commit ffc52aa

File tree

2 files changed

+123
-5
lines changed

2 files changed

+123
-5
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToDenseVector.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.compute.ann.ConvertEvaluator;
13+
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
1314
import org.elasticsearch.xpack.esql.core.expression.Expression;
1415
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1516
import org.elasticsearch.xpack.esql.core.tree.Source;
@@ -25,9 +26,10 @@
2526
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
2627
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
2728
import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER;
29+
import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD;
2830
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
2931

30-
public class ToDenseVector extends AbstractConvertFunction {
32+
public class ToDenseVector extends AbstractConvertFunction implements PostAnalysisVerificationAware {
3133
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
3234
Expression.class,
3335
"ToDenseVector",
@@ -38,20 +40,21 @@ public class ToDenseVector extends AbstractConvertFunction {
3840
Map.entry(DENSE_VECTOR, (source, fieldEval) -> fieldEval),
3941
Map.entry(LONG, ToDenseVectorFromLongEvaluator.Factory::new),
4042
Map.entry(INTEGER, ToDenseVectorFromIntEvaluator.Factory::new),
41-
Map.entry(DOUBLE, ToDenseVectorFromDoubleEvaluator.Factory::new)
43+
Map.entry(DOUBLE, ToDenseVectorFromDoubleEvaluator.Factory::new),
44+
Map.entry(KEYWORD, ToDenseVectorFromStringEvaluator.Factory::new)
4245
);
4346

4447
@FunctionInfo(
4548
returnType = "dense_vector",
46-
description = "Converts a multi-valued input of numbers to a dense_vector.",
49+
description = "Converts a multi-valued input of numbers, or a hexadecimal string, to a dense_vector.",
4750
examples = @Example(file = "dense_vector", tag = "to_dense_vector-ints")
4851
)
4952
public ToDenseVector(
5053
Source source,
5154
@Param(
5255
name = "field",
53-
type = {"double", "long", "integer"},
54-
description = "multi-valued input of numbers to convert."
56+
type = {"double", "long", "integer", "keyword"},
57+
description = "multi-valued input of numbers or hexadecimal string to convert."
5558
) Expression field
5659
) {
5760
super(source, field);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.scalar.convert;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.apache.lucene.util.RamUsageEstimator;
12+
import org.elasticsearch.compute.data.Block;
13+
import org.elasticsearch.compute.data.BytesRefBlock;
14+
import org.elasticsearch.compute.data.FloatBlock;
15+
import org.elasticsearch.compute.data.Vector;
16+
import org.elasticsearch.compute.operator.DriverContext;
17+
import org.elasticsearch.compute.operator.EvalOperator;
18+
import org.elasticsearch.core.Releasables;
19+
import org.elasticsearch.xpack.esql.core.tree.Source;
20+
21+
import java.util.HexFormat;
22+
23+
public class ToDenseVectorFromStringEvaluator extends AbstractConvertFunction.AbstractEvaluator {
24+
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(
25+
ToDenseVectorFromStringEvaluator.class
26+
);
27+
28+
private final EvalOperator.ExpressionEvaluator field;
29+
30+
public ToDenseVectorFromStringEvaluator(Source source, EvalOperator.ExpressionEvaluator field, DriverContext driverContext) {
31+
super(driverContext, source);
32+
this.field = field;
33+
}
34+
35+
@Override
36+
protected EvalOperator.ExpressionEvaluator next() {
37+
return field;
38+
}
39+
40+
@Override
41+
protected Block evalVector(Vector v) {
42+
return evalBlock(v.asBlock());
43+
}
44+
45+
@Override
46+
public Block evalBlock(Block b) {
47+
BytesRefBlock block = (BytesRefBlock) b;
48+
int positionCount = block.getPositionCount();
49+
int dimensions = 0;
50+
BytesRef scratch = new BytesRef();
51+
try (FloatBlock.Builder builder = driverContext.blockFactory().newFloatBlockBuilder(positionCount * dimensions)) {
52+
for (int p = 0; p < positionCount; p++) {
53+
if (block.isNull(p)) {
54+
builder.appendNull();
55+
} else {
56+
scratch = block.getBytesRef(p, scratch);
57+
byte[] bytes = HexFormat.of().parseHex(scratch.utf8ToString());
58+
if (bytes.length == 0) {
59+
builder.appendNull();
60+
continue;
61+
}
62+
if (dimensions == 0) {
63+
dimensions = bytes.length;
64+
} else {
65+
if (bytes.length != dimensions) {
66+
throw new IllegalArgumentException("All dense_vector must have the same number of dimensions. Expected: "
67+
+ dimensions + ", found: " + bytes.length);
68+
}
69+
}
70+
builder.beginPositionEntry();
71+
for (byte value : bytes) {
72+
builder.appendFloat(value);
73+
}
74+
builder.endPositionEntry();
75+
}
76+
}
77+
return builder.build();
78+
}
79+
}
80+
81+
@Override
82+
public String toString() {
83+
return "ToDenseVectorFromStringEvaluator[s=" + field + ']';
84+
}
85+
86+
@Override
87+
public long baseRamBytesUsed() {
88+
return BASE_RAM_BYTES_USED + field.baseRamBytesUsed();
89+
}
90+
91+
@Override
92+
public void close() {
93+
Releasables.closeExpectNoException(field);
94+
}
95+
96+
public static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
97+
private final Source source;
98+
private final EvalOperator.ExpressionEvaluator.Factory field;
99+
100+
public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field) {
101+
this.source = source;
102+
this.field = field;
103+
}
104+
105+
@Override
106+
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
107+
return new ToDenseVectorFromStringEvaluator(source, field.get(context), context);
108+
}
109+
110+
@Override
111+
public String toString() {
112+
return "ToDenseVectorFromStringEvaluator[s=" + field + ']';
113+
}
114+
}
115+
}

0 commit comments

Comments
 (0)