|
19 | 19 | import org.elasticsearch.xpack.esql.core.type.DataType; |
20 | 20 | import org.elasticsearch.xpack.esql.expression.function.Param; |
21 | 21 | import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; |
22 | | -import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.AbstractMultivalueFunction; |
23 | 22 |
|
24 | 23 | import java.io.IOException; |
25 | 24 | import java.util.List; |
@@ -86,93 +85,84 @@ public Expression right() { |
86 | 85 | return right; |
87 | 86 | } |
88 | 87 |
|
89 | | - protected abstract class SimilarityEvaluatorFactory implements EvalOperator.ExpressionEvaluator.Factory { |
| 88 | + @FunctionalInterface |
| 89 | + public interface SimilarityEvaluatorFunction { |
| 90 | + float calculateSimilarity(float[] leftScratch, float[] rightScratch); |
| 91 | + } |
90 | 92 |
|
91 | | - protected final EvalOperator.ExpressionEvaluator.Factory left; |
92 | | - protected final EvalOperator.ExpressionEvaluator.Factory right; |
| 93 | + protected class SimilarityEvaluatorFactory implements EvalOperator.ExpressionEvaluator.Factory { |
93 | 94 |
|
94 | | - SimilarityEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory left, EvalOperator.ExpressionEvaluator.Factory right) { |
95 | | - this.left = left; |
96 | | - this.right = right; |
97 | | - } |
| 95 | + private final EvalOperator.ExpressionEvaluator.Factory left; |
| 96 | + private final EvalOperator.ExpressionEvaluator.Factory right; |
| 97 | + private final SimilarityEvaluatorFunction similarityFunction; |
98 | 98 |
|
99 | | - @Override |
100 | | - public EvalOperator.ExpressionEvaluator get(DriverContext context) { |
101 | | - return getSimilarityEvaluator(context); |
102 | | - } |
103 | | - |
104 | | - protected abstract SimilarityEvaluator getSimilarityEvaluator(DriverContext context); |
105 | | - } |
106 | | - |
107 | | - /** |
108 | | - * Evaluator for {@link CosineSimilarity}. Not generated and doesn’t extend from |
109 | | - * {@link AbstractMultivalueFunction.AbstractEvaluator} because it’s different from {@link org.elasticsearch.compute.ann.MvEvaluator} |
110 | | - * or scalar evaluators. |
111 | | - * <p> |
112 | | - * We can probably generalize to a common class or use its own annotation / evaluator template |
113 | | - */ |
114 | | - protected abstract static class SimilarityEvaluator implements EvalOperator.ExpressionEvaluator { |
115 | | - private final DriverContext context; |
116 | | - private final EvalOperator.ExpressionEvaluator left; |
117 | | - private final EvalOperator.ExpressionEvaluator right; |
118 | | - |
119 | | - SimilarityEvaluator(DriverContext context, EvalOperator.ExpressionEvaluator left, EvalOperator.ExpressionEvaluator right) { |
120 | | - this.context = context; |
| 99 | + SimilarityEvaluatorFactory( |
| 100 | + EvalOperator.ExpressionEvaluator.Factory left, |
| 101 | + EvalOperator.ExpressionEvaluator.Factory right, |
| 102 | + SimilarityEvaluatorFunction similarityFunction |
| 103 | + ) { |
121 | 104 | this.left = left; |
122 | 105 | this.right = right; |
| 106 | + this.similarityFunction = similarityFunction; |
123 | 107 | } |
124 | 108 |
|
125 | 109 | @Override |
126 | | - public final Block eval(Page page) { |
127 | | - try (FloatBlock leftBlock = (FloatBlock) left.eval(page); FloatBlock rightBlock = (FloatBlock) right.eval(page)) { |
128 | | - int positionCount = page.getPositionCount(); |
129 | | - if (positionCount == 0) { |
130 | | - return context.blockFactory().newConstantFloatBlockWith(0F, 0); |
| 110 | + public EvalOperator.ExpressionEvaluator get(DriverContext context) { |
| 111 | + return new EvalOperator.ExpressionEvaluator() { |
| 112 | + @Override |
| 113 | + public Block eval(Page page) { |
| 114 | + try ( |
| 115 | + FloatBlock leftBlock = (FloatBlock) left.get(context).eval(page); |
| 116 | + FloatBlock rightBlock = (FloatBlock) right.get(context).eval(page) |
| 117 | + ) { |
| 118 | + int positionCount = page.getPositionCount(); |
| 119 | + if (positionCount == 0) { |
| 120 | + return context.blockFactory().newConstantFloatBlockWith(0F, 0); |
| 121 | + } |
| 122 | + |
| 123 | + int dimensions = leftBlock.getValueCount(0); |
| 124 | + int dimsRight = rightBlock.getValueCount(0); |
| 125 | + assert dimensions == dimsRight |
| 126 | + : "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight; |
| 127 | + float[] leftScratch = new float[dimensions]; |
| 128 | + float[] rightScratch = new float[dimensions]; |
| 129 | + try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) { |
| 130 | + for (int p = 0; p < positionCount; p++) { |
| 131 | + assert leftBlock.getValueCount(p) == dimensions |
| 132 | + : "Left vector must have the same value count for all positions, but got left: " |
| 133 | + + leftBlock.getValueCount(p) |
| 134 | + + ", expected: " |
| 135 | + + dimensions; |
| 136 | + assert rightBlock.getValueCount(p) == dimensions |
| 137 | + : "Left vector must have the same value count for all positions, but got left: " |
| 138 | + + rightBlock.getValueCount(p) |
| 139 | + + ", expected: " |
| 140 | + + dimensions; |
| 141 | + |
| 142 | + readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch); |
| 143 | + readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch); |
| 144 | + float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch); |
| 145 | + builder.appendDouble(result); |
| 146 | + } |
| 147 | + return builder.build().asBlock(); |
| 148 | + } |
| 149 | + } |
131 | 150 | } |
132 | 151 |
|
133 | | - int dimensions = leftBlock.getValueCount(0); |
134 | | - int dimsRight = rightBlock.getValueCount(0); |
135 | | - assert dimensions == dimsRight |
136 | | - : "Left and right vector must have the same value count, but got left: " + dimensions + ", right: " + dimsRight; |
137 | | - float[] leftScratch = new float[dimensions]; |
138 | | - float[] rightScratch = new float[dimensions]; |
139 | | - try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount * dimensions)) { |
140 | | - for (int p = 0; p < positionCount; p++) { |
141 | | - assert leftBlock.getValueCount(p) == dimensions |
142 | | - : "Left vector must have the same value count for all positions, but got left: " |
143 | | - + leftBlock.getValueCount(p) |
144 | | - + ", expected: " |
145 | | - + dimensions; |
146 | | - assert rightBlock.getValueCount(p) == dimensions |
147 | | - : "Left vector must have the same value count for all positions, but got left: " |
148 | | - + rightBlock.getValueCount(p) |
149 | | - + ", expected: " |
150 | | - + dimensions; |
151 | | - |
152 | | - readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimensions, leftScratch); |
153 | | - readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimensions, rightScratch); |
154 | | - float result = calculateSimilarity(leftScratch, rightScratch); |
155 | | - builder.appendDouble(result); |
156 | | - } |
157 | | - return builder.build().asBlock(); |
| 152 | + @Override |
| 153 | + public void close() {} |
| 154 | + |
| 155 | + @Override |
| 156 | + public String toString() { |
| 157 | + return "ExpressionEvaluator[left=" + left + ", right=" + right + "]"; |
158 | 158 | } |
159 | | - } |
| 159 | + }; |
160 | 160 | } |
161 | 161 |
|
162 | | - protected abstract float calculateSimilarity(float[] leftScratch, float[] rightScratch); |
163 | | - |
164 | 162 | private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { |
165 | 163 | for (int i = 0; i < dimensions; i++) { |
166 | 164 | scratch[i] = block.getFloat(position + i); |
167 | 165 | } |
168 | 166 | } |
169 | | - |
170 | | - @Override |
171 | | - public final String toString() { |
172 | | - return getClass().getSimpleName() + "=" + left + ", right=" + right + "]"; |
173 | | - } |
174 | | - |
175 | | - @Override |
176 | | - public void close() {} |
177 | 167 | } |
178 | 168 | } |
0 commit comments