Skip to content

Commit 5098373

Browse files
committed
Extract superclass, use overridden method
1 parent 16542b5 commit 5098373

File tree

2 files changed

+198
-145
lines changed

2 files changed

+198
-145
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

Lines changed: 20 additions & 145 deletions
Original file line number | Diff line number | Diff line change
@@ -7,48 +7,33 @@
77

88
package org.elasticsearch.xpack.esql.expression.function.vector;
99

10-
import org.apache.lucene.index.VectorSimilarityFunction;
1110
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1211
import org.elasticsearch.common.io.stream.StreamInput;
13-
import org.elasticsearch.common.io.stream.StreamOutput;
14-
import org.elasticsearch.compute.data.Block;
15-
import org.elasticsearch.compute.data.DoubleVector;
16-
import org.elasticsearch.compute.data.FloatBlock;
17-
import org.elasticsearch.compute.data.Page;
1812
import org.elasticsearch.compute.operator.DriverContext;
1913
import org.elasticsearch.compute.operator.EvalOperator;
2014
import org.elasticsearch.xpack.esql.core.expression.Expression;
2115
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
2216
import org.elasticsearch.xpack.esql.core.tree.Source;
23-
import org.elasticsearch.xpack.esql.core.type.DataType;
2417
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
2518
import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
2619
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
2720
import org.elasticsearch.xpack.esql.expression.function.Param;
28-
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
2921
import org.elasticsearch.xpack.esql.expression.function.scalar.math.Pow;
30-
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
3122
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
3223

3324
import java.io.IOException;
3425
import java.util.List;
3526

36-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
37-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
38-
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
39-
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
27+
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
4028

41-
public class CosineSimilarity extends EsqlScalarFunction implements VectorFunction {
29+
public class CosineSimilarity extends org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction {
4230

4331
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
4432
Expression.class,
4533
"CosineSimilarity",
4634
CosineSimilarity::new
4735
);
4836

49-
private Expression left;
50-
private Expression right;
51-
5237
@FunctionInfo(
5338
returnType = "double",
5439
preview = true,
@@ -57,58 +42,20 @@ public class CosineSimilarity extends EsqlScalarFunction implements VectorFuncti
5742
)
5843
public CosineSimilarity(
5944
Source source,
60-
@Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity")
61-
Expression left,
62-
@Param(name = "right", type = { "dense_vector" }, description = "second dense_vector to calculate cosine similarity")
63-
Expression right
45+
@Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left,
46+
@Param(
47+
name = "right",
48+
type = { "dense_vector" },
49+
description = "second dense_vector to calculate cosine similarity"
50+
) Expression right
6451
) {
65-
super(source, List.of(left, right));
66-
this.left = left;
67-
this.right = right;
52+
super(source, List.of(left, right), left, right);
6853
}
6954

7055
private CosineSimilarity(StreamInput in) throws IOException {
7156
this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
7257
}
7358

74-
@Override
75-
public void writeTo(StreamOutput out) throws IOException {
76-
source().writeTo(out);
77-
out.writeNamedWriteable(left());
78-
out.writeNamedWriteable(right());
79-
}
80-
81-
@Override
82-
public DataType dataType() {
83-
return DataType.DOUBLE;
84-
}
85-
86-
@Override
87-
protected TypeResolution resolveType() {
88-
if (childrenResolved() == false) {
89-
return new TypeResolution("Unresolved children");
90-
}
91-
92-
return checkDenseVectorParam(left()).and(checkDenseVectorParam(right()));
93-
}
94-
95-
private TypeResolution checkDenseVectorParam(Expression param) {
96-
return isNotNull(param, sourceText(), FIRST).and(isType(param, dt -> dt == DENSE_VECTOR, sourceText(), FIRST, "dense_vector"));
97-
}
98-
99-
@Override
100-
public Expression replaceChildren(List<Expression> newChildren) {
101-
return new CosineSimilarity(source(), newChildren.get(0), newChildren.get(1));
102-
}
103-
104-
public Expression left() {
105-
return left;
106-
}
107-
108-
public Expression right() {
109-
return right;
110-
}
111-
11259
@Override
11360
protected NodeInfo<? extends Expression> info() {
11461
return NodeInfo.create(this, Pow::new, left(), right());
@@ -121,94 +68,22 @@ public String getWriteableName() {
12168

12269
@Override
12370
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
124-
return new EvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right()));
125-
}
126-
127-
private record EvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory left, EvalOperator.ExpressionEvaluator.Factory right)
128-
implements
129-
EvalOperator.ExpressionEvaluator.Factory {
130-
@Override
131-
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
132-
return new Evaluator(context, left.get(context), right.get(context));
133-
}
134-
135-
@Override
136-
public String toString() {
137-
return "CosineSimilarity[left=" + left + ", right=" + right + "]";
138-
}
139-
}
140-
141-
/**
142-
* Evaluator for {@link CosineSimilarity}. Not generated and doesn’t extend from
143-
* {@link AbstractMultivalueFunction.AbstractEvaluator} because it’s different from {@link org.elasticsearch.compute.ann.MvEvaluator}
144-
* or scalar evaluators.
145-
*
146-
* We can probably generalize to a common class or use its own annotation / evaluator template
147-
*/
148-
private static class Evaluator implements EvalOperator.ExpressionEvaluator {
149-
private final DriverContext context;
150-
private final EvalOperator.ExpressionEvaluator left;
151-
private final EvalOperator.ExpressionEvaluator right;
152-
153-
Evaluator(DriverContext context, EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right) {
154-
this.context = context;
155-
this.left = left;
156-
this.right = right;
157-
}
158-
159-
@Override
160-
public final Block eval(Page page) {
161-
try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) {
162-
int positionCount = page.getPositionCount();
163-
if (positionCount == 0) {
164-
return context.blockFactory().newConstantFloatBlockWith(0F, 0);
165-
}
166-
167-
int dimensions = leftBlock.getValueCount(0);
168-
int dimsRight = rightBlock.getValueCount(0);
169-
assert dimensions == dimsRight
170-
: "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight;
171-
float[] leftScratch = new float[dimensions];
172-
float[] rightScratch = new float[dimensions];
173-
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
174-
for (int p = 0; p < positionCount; p++) {
175-
assert leftBlock.getValueCount(p) == dimensions
176-
: "Left vector must have the same value count for all positions, but got left: "
177-
+ leftBlock.getValueCount(p)
178-
+ ", expected: "
179-
+ dimensions;
180-
assert rightBlock.getValueCount(p) == dimensions
181-
: "Left vector must have the same value count for all positions, but got left: "
182-
+ rightBlock.getValueCount(p)
183-
+ ", expected: "
184-
+ dimensions;
185-
186-
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
187-
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
188-
float result = calculateSimilarity(leftScratch, rightScratch);
189-
builder.appendDouble(result);
190-
}
191-
return builder.build().asBlock();
192-
}
71+
return new VectorSimilarityFunction.SimilarityEvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right())) {
72+
@Override
73+
protected SimilarityEvaluator getSimilarityEvaluator(DriverContext context) {
74+
return new CosineSimilarityEvaluator(context, left.get(context), right.get(context));
19375
}
194-
}
195-
196-
private float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
197-
return VectorSimilarityFunction.COSINE.compare(leftScratch, rightScratch);
198-
}
76+
};
77+
}
19978

200-
private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
201-
for (int i = 0; i < dimensions; i++) {
202-
scratch[i] = block.getFloat(position + i);
203-
}
79+
private static class CosineSimilarityEvaluator extends VectorSimilarityFunction.SimilarityEvaluator {
80+
CosineSimilarityEvaluator(DriverContext context, EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right) {
81+
super(context, left, right);
20482
}
20583

20684
@Override
207-
public final String toString() {
208-
return "CosineSimilarity[left=" + left + ", right=" + right + "]";
85+
protected float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
86+
return COSINE.compare(leftScratch, rightScratch);
20987
}
210-
211-
@Override
212-
public void close() {}
21388
}
21489
}

0 commit comments

Comments
 (0)