Skip to content

Commit b441d60

Browse files
committed
Use lambda instead of overriden methods
1 parent 5098373 commit b441d60

File tree

2 files changed

+66
-89
lines changed

2 files changed

+66
-89
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
1111
import org.elasticsearch.common.io.stream.StreamInput;
12-
import org.elasticsearch.compute.operator.DriverContext;
1312
import org.elasticsearch.compute.operator.EvalOperator;
1413
import org.elasticsearch.xpack.esql.core.expression.Expression;
1514
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
@@ -68,22 +67,10 @@ public String getWriteableName() {
6867

6968
@Override
7069
public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) {
71-
return new VectorSimilarityFunction.SimilarityEvaluatorFactory(toEvaluator.apply(left()), toEvaluator.apply(right())) {
72-
@Override
73-
protected SimilarityEvaluator getSimilarityEvaluator(DriverContext context) {
74-
return new CosineSimilarityEvaluator(context, left.get(context), right.get(context));
75-
}
76-
};
77-
}
78-
79-
private static class CosineSimilarityEvaluator extends VectorSimilarityFunction.SimilarityEvaluator {
80-
CosineSimilarityEvaluator(DriverContext context, EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right) {
81-
super(context, left, right);
82-
}
83-
84-
@Override
85-
protected float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
86-
return COSINE.compare(leftScratch, rightScratch);
87-
}
70+
return new SimilarityEvaluatorFactory(
71+
toEvaluator.apply(left()),
72+
toEvaluator.apply(right()),
73+
(leftScratch, rightScratch) -> COSINE.compare(leftScratch, rightScratch)
74+
);
8875
}
8976
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 61 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.elasticsearch.xpack.esql.core.type.DataType;
2020
import org.elasticsearch.xpack.esql.expression.function.Param;
2121
import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
22-
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction;
2322

2423
import java.io.IOException;
2524
import java.util.List;
@@ -86,93 +85,84 @@ public Expression right() {
8685
return right;
8786
}
8887

89-
protected abstract class SimilarityEvaluatorFactory implements EvalOperator.ExpressionEvaluator.Factory {
88+
@FunctionalInterface
89+
public interface SimilarityEvaluatorFunction {
90+
float calculateSimilarity(float[] leftScratch, float[] rightScratch);
91+
}
9092

91-
protected final EvalOperator.ExpressionEvaluator.Factory left;
92-
protected final EvalOperator.ExpressionEvaluator.Factory right;
93+
protected class SimilarityEvaluatorFactory implements EvalOperator.ExpressionEvaluator.Factory {
9394

94-
SimilarityEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory left, EvalOperator.ExpressionEvaluator.Factory right) {
95-
this.left = left;
96-
this.right = right;
97-
}
95+
private final EvalOperator.ExpressionEvaluator.Factory left;
96+
private final EvalOperator.ExpressionEvaluator.Factory right;
97+
private final SimilarityEvaluatorFunction similarityFunction;
9898

99-
@Override
100-
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
101-
return getSimilarityEvaluator(context);
102-
}
103-
104-
protected abstract SimilarityEvaluator getSimilarityEvaluator(DriverContext context);
105-
}
106-
107-
/**
108-
* Evaluator for {@link CosineSimilarity}. Not generated and doesn’t extend from
109-
* {@link AbstractMultivalueFunction.AbstractEvaluator} because it’s different from {@link org.elasticsearch.compute.ann.MvEvaluator}
110-
* or scalar evaluators.
111-
* <p>
112-
* We can probably generalize to a common class or use its own annotation / evaluator template
113-
*/
114-
protected abstract static class SimilarityEvaluator implements EvalOperator.ExpressionEvaluator {
115-
private final DriverContext context;
116-
private final EvalOperator.ExpressionEvaluator left;
117-
private final EvalOperator.ExpressionEvaluator right;
118-
119-
SimilarityEvaluator(DriverContext context, EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right) {
120-
this.context = context;
99+
SimilarityEvaluatorFactory(
100+
EvalOperator.ExpressionEvaluator.Factory left,
101+
EvalOperator.ExpressionEvaluator.Factory right,
102+
SimilarityEvaluatorFunction similarityFunction
103+
) {
121104
this.left = left;
122105
this.right = right;
106+
this.similarityFunction = similarityFunction;
123107
}
124108

125109
@Override
126-
public final Block eval(Page page) {
127-
try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) {
128-
int positionCount = page.getPositionCount();
129-
if (positionCount == 0) {
130-
return context.blockFactory().newConstantFloatBlockWith(0F, 0);
110+
public EvalOperator.ExpressionEvaluator get(DriverContext context) {
111+
return new EvalOperator.ExpressionEvaluator() {
112+
@Override
113+
public Block eval(Page page) {
114+
try (
115+
FloatBlock leftBlock = (FloatBlock) left.get(context).eval(page);
116+
FloatBlock rightBlock = (FloatBlock) right.get(context).eval(page)
117+
) {
118+
int positionCount = page.getPositionCount();
119+
if (positionCount == 0) {
120+
return context.blockFactory().newConstantFloatBlockWith(0F, 0);
121+
}
122+
123+
int dimensions = leftBlock.getValueCount(0);
124+
int dimsRight = rightBlock.getValueCount(0);
125+
assert dimensions == dimsRight
126+
: "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight;
127+
float[] leftScratch = new float[dimensions];
128+
float[] rightScratch = new float[dimensions];
129+
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
130+
for (int p = 0; p < positionCount; p++) {
131+
assert leftBlock.getValueCount(p) == dimensions
132+
: "Left vector must have the same value count for all positions, but got left: "
133+
+ leftBlock.getValueCount(p)
134+
+ ", expected: "
135+
+ dimensions;
136+
assert rightBlock.getValueCount(p) == dimensions
137+
: "Left vector must have the same value count for all positions, but got left: "
138+
+ rightBlock.getValueCount(p)
139+
+ ", expected: "
140+
+ dimensions;
141+
142+
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
143+
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
144+
float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
145+
builder.appendDouble(result);
146+
}
147+
return builder.build().asBlock();
148+
}
149+
}
131150
}
132151

133-
int dimensions = leftBlock.getValueCount(0);
134-
int dimsRight = rightBlock.getValueCount(0);
135-
assert dimensions == dimsRight
136-
: "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight;
137-
float[] leftScratch = new float[dimensions];
138-
float[] rightScratch = new float[dimensions];
139-
try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) {
140-
for (int p = 0; p < positionCount; p++) {
141-
assert leftBlock.getValueCount(p) == dimensions
142-
: "Left vector must have the same value count for all positions, but got left: "
143-
+ leftBlock.getValueCount(p)
144-
+ ", expected: "
145-
+ dimensions;
146-
assert rightBlock.getValueCount(p) == dimensions
147-
: "Left vector must have the same value count for all positions, but got left: "
148-
+ rightBlock.getValueCount(p)
149-
+ ", expected: "
150-
+ dimensions;
151-
152-
readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch);
153-
readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch);
154-
float result = calculateSimilarity(leftScratch, rightScratch);
155-
builder.appendDouble(result);
156-
}
157-
return builder.build().asBlock();
152+
@Override
153+
public void close() {}
154+
155+
@Override
156+
public String toString() {
157+
return "ExpressionEvaluator[left=" + left + ", right=" + right + "]";
158158
}
159-
}
159+
};
160160
}
161161

162-
protected abstract float calculateSimilarity(float[] leftScratch, float[] rightScratch);
163-
164162
private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
165163
for (int i = 0; i < dimensions; i++) {
166164
scratch[i] = block.getFloat(position + i);
167165
}
168166
}
169-
170-
@Override
171-
public final String toString() {
172-
return getClass().getSimpleName() + "=" + left + ", right=" + right + "]";
173-
}
174-
175-
@Override
176-
public void close() {}
177167
}
178168
}

0 commit comments

Comments
 (0)