Skip to content

Commit 0c18eba

Browse files
committed
Add CosineSimilarity function
1 parent 72b5c01 commit 0c18eba

File tree

5 files changed

+242
-8
lines changed

5 files changed

+242
-8
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1215,7 +1215,9 @@ public enum Cap {
12151215
/**
12161216
* (Re)Added EXPLAIN command
12171217
*/
1218-
EXPLAIN(Build.current().isSnapshot());
1218+
EXPLAIN(Build.current().isSnapshot()),
1219+
1220+
COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
12191221

12201222
private final boolean enabled;
12211223

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package org.elasticsearch.xpack.esql.expression;
99

1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11-
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
1211
import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
1312
import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
1413
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
@@ -85,7 +84,7 @@
8584
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
8685
import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList;
8786
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
88-
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
87+
import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables;
8988
import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
9089
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
9190
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
@@ -259,9 +258,6 @@ private static List<NamedWriteableRegistry.Entry> fullText() {
259258
}
260259

261260
private static List<NamedWriteableRegistry.Entry> vector() {
262-
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
263-
return List.of(Knn.ENTRY);
264-
}
265-
return List.of();
261+
return VectorWritables.getNamedWritables();
266262
}
267263
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
180180
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
181181
import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
182+
import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
182183
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
183184
import org.elasticsearch.xpack.esql.parser.ParsingException;
184185
import org.elasticsearch.xpack.esql.session.Configuration;
@@ -487,7 +488,8 @@ private static FunctionDefinition[][] snapshotFunctions() {
487488
def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
488489
def(StGeohex.class, StGeohex::new, "st_geohex"),
489490
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
490-
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
491+
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
492+
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine_similarity") } };
491493
}
492494

493495
public EsqlFunctionRegistry snapshotRegistry() {
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.vector;
9+
10+
import org.apache.lucene.index.VectorSimilarityFunction;
11+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.compute.data.Block;
15+
import org.elasticsearch.compute.data.DoubleVector;
16+
import org.elasticsearch.compute.data.FloatBlock;
17+
import org.elasticsearch.compute.data.Page;
18+
import org.elasticsearch.compute.operator.DriverContext;
19+
import org.elasticsearch.compute.operator.EvalOperator;
20+
import org.elasticsearch.xpack.esql.core.expression.Expression;
21+
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
22+
import org.elasticsearch.xpack.esql.core.tree.Source;
23+
import org.elasticsearch.xpack.esql.core.type.DataType;
24+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
25+
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
26+
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
27+
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
28+
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
29+
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
30+
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
31+
32+
import java.io.IOException;
33+
import java.util.List;
34+
35+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
36+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
37+
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
38+
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
39+
40+
public class CosineSimilarity extends EsqlScalarFunction implements VectorFunction {
41+
42+
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
43+
Expression.class,
44+
"CosineSimilarity",
45+
CosineSimilarity::new
46+
);
47+
48+
private Expression left;
49+
private Expression right;
50+
51+
@FunctionInfo(
52+
returnType = "double",
53+
preview = true,
54+
description = "Calculates the cosine similarity between two dense_vectors.",
55+
appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
56+
)
57+
public CosineSimilarity(Source source, Expression left, Expression right) {
58+
super(source, List.of(left, right));
59+
this.left = left;
60+
this.right = right;
61+
}
62+
63+
@Override
64+
public DataType dataType() {
65+
return DataType.DOUBLE;
66+
}
67+
68+
private CosineSimilarity(StreamInput in) throws IOException {
69+
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
70+
}
71+
72+
@Override
73+
protected TypeResolution resolveType() {
74+
if (childrenResolved() == false) {
75+
return new TypeResolution("Unresolved children");
76+
}
77+
78+
return checkParam(left()).and(checkParam(right()));
79+
}
80+
81+
private TypeResolution checkParam(Expression param) {
82+
return isNotNull(param, sourceText(), FIRST).and(isType(param, dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
83+
}
84+
85+
@Override
86+
public void writeTo(StreamOutput out) throws IOException {
87+
source().writeTo(out);
88+
out.writeNamedWriteable(left());
89+
out.writeNamedWriteable(right());
90+
}
91+
92+
@Override
93+
public Expression replaceChildren(List<Expression> newChildren) {
94+
return new CosineSimilarity(source(), newChildren.get(0), newChildren.get(1));
95+
}
96+
97+
public Expression left() {
98+
return left;
99+
}
100+
101+
public Expression right() {
102+
return right;
103+
}
104+
105+
@Override
106+
protected NodeInfo<? extends Expression> info() {
107+
return NodeInfo.create(this, Pow::new, left(), right());
108+
}
109+
110+
@Override
111+
public String getWriteableName() {
112+
return ENTRY.name;
113+
}
114+
115+
@Override
116+
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
117+
return new EvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right()));
118+
}
119+
120+
private record EvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory left, EvalOperator.ExpressionEvaluator.Factory right)
121+
implements
122+
EvalOperator.ExpressionEvaluator.Factory {
123+
@Override
124+
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
125+
return new Evaluator(context, left.get(context), right.get(context));
126+
}
127+
128+
@Override
129+
public String toString() {
130+
return "CosineSimilarity[left=" + left + ", right=" + right + "]";
131+
}
132+
}
133+
134+
/**
135+
* Evaluator for {@link CosineSimilarity}. Not generated and doesn’t extend from
136+
* {@link AbstractMultivalueFunction.AbstractEvaluator} because it’s different from {@link org.elasticsearch.compute.ann.MvEvaluator}
137+
* or scalar evaluators.
138+
*
139+
* We can probably generalize to a common class or use its own annotation / evaluator template
140+
*/
141+
private static class Evaluator implements EvalOperator.ExpressionEvaluator {
142+
private final DriverContext context;
143+
private final EvalOperator.ExpressionEvaluator left;
144+
private final EvalOperator.ExpressionEvaluator right;
145+
146+
Evaluator(DriverContext context, EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right) {
147+
this.context = context;
148+
this.left = left;
149+
this.right = right;
150+
}
151+
152+
@Override
153+
public final Block eval(Page page) {
154+
try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) {
155+
int positionCount = page.getPositionCount();
156+
if (positionCount == 0) {
157+
return context.blockFactory().newConstantFloatBlockWith(0F, 0);
158+
}
159+
160+
int dimensions = leftBlock.getValueCount(0);
161+
int dimsRight = rightBlock.getValueCount(0);
162+
assert dimensions == dimsRight
163+
: "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight;
164+
float[] leftScratch = new float[dimensions];
165+
float[] rightScratch = new float[dimensions];
166+
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
167+
for (int p = 0; p < positionCount; p++) {
168+
assert leftBlock.getValueCount(p) == dimensions
169+
: "Left vector must have the same value count for all positions, but got left: "
170+
+ leftBlock.getValueCount(p)
171+
+ ", expected: "
172+
+ dimensions;
173+
assert rightBlock.getValueCount(p) == dimensions
174+
: "Left vector must have the same value count for all positions, but got left: "
175+
+ rightBlock.getValueCount(p)
176+
+ ", expected: "
177+
+ dimensions;
178+
179+
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
180+
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
181+
float result = VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
182+
builder.appendDouble(result);
183+
}
184+
return builder.build().asBlock();
185+
}
186+
}
187+
}
188+
189+
private void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
190+
for (int i = 0; i < dimensions; i++) {
191+
scratch[i] = block.getFloat(position + i);
192+
}
193+
}
194+
195+
@Override
196+
public final String toString() {
197+
return "CosineSimilarity[left=" + left + ", right=" + right + "]";
198+
}
199+
200+
@Override
201+
public void close() {}
202+
}
203+
204+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.vector;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
12+
13+
import java.util.ArrayList;
14+
import java.util.Collections;
15+
import java.util.List;
16+
17+
public class VectorWritables {
18+
public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
19+
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
20+
21+
if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
22+
return List.of(Knn.ENTRY);
23+
}
24+
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
25+
entries.add(CosineSimilarity.ENTRY);
26+
}
27+
28+
return Collections.unmodifiableList(entries);
29+
}
30+
}

0 commit comments

Comments
 (0)