-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Implement v_magnitude function #132765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement v_magnitude function #132765
Changes from 36 commits
79f6ba4
0a80914
8a9967f
c09f349
bf661a5
035b14d
722b2f9
e1e4f96
f9035d6
3f7dfb7
d7bf82a
e9c5d0c
6feed0b
427c703
cc5f4f7
a5091ad
ad880d9
68d79a4
f620f4b
f6d333c
2ce2cf8
4c024b0
1478fdc
50dce30
1f39b8f
3743261
3ce2f53
08ecefd
98c3d9b
743804c
1691bff
5d1effc
bbd72b8
612ca48
2a9a64d
eb3d0f6
84df3be
35b5f6c
c542420
db3e76d
b7e9933
4c7657a
cbd7d33
d1b9fbe
0828de8
e24f914
fd46a72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| pr: 132765 | ||
| summary: Implement `v_magnitude` function | ||
| area: ES|QL | ||
| type: feature | ||
| issues: [] | ||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # Tests for v_magnitude scalar function | ||
|
|
||
| magnitudeWithVectorField | ||
| required_capability: magnitude_scalar_vector_function | ||
|
|
||
| // tag::vector-magnitude[] | ||
| from colors | ||
| | eval magnitude = v_magnitude(rgb_vector) | ||
| | sort magnitude desc, color asc | ||
| // end::vector-magnitude[] | ||
| | limit 10 | ||
| | keep color, magnitude | ||
| ; | ||
|
|
||
| // tag::vector-magnitude-result[] | ||
| color:text | magnitude:double | ||
| white | 441.6729431152344 | ||
| snow | 435.9185791015625 | ||
| azure | 433.1858825683594 | ||
| ivory | 433.1858825683594 | ||
| mint cream | 433.0704345703125 | ||
| sea shell | 426.25579833984375 | ||
| honeydew | 424.5291442871094 | ||
| old lace | 420.6352233886719 | ||
| corn silk | 418.2451477050781 | ||
| linen | 415.93267822265625 | ||
| // end::vector-magnitude-result[] | ||
| ; | ||
|
|
||
| magnitudeAsPartOfExpression | ||
| required_capability: magnitude_scalar_vector_function | ||
|
|
||
| from colors | ||
| | eval score = round((1 + v_magnitude(rgb_vector) / 2), 3) | ||
| | sort score desc, color asc | ||
| | limit 10 | ||
| | keep color, score | ||
| ; | ||
|
|
||
| color:text | score:double | ||
| white | 221.836 | ||
| snow | 218.959 | ||
| azure | 217.593 | ||
| ivory | 217.593 | ||
| mint cream | 217.535 | ||
| sea shell | 214.128 | ||
| honeydew | 213.265 | ||
| old lace | 211.318 | ||
| corn silk | 210.123 | ||
| linen | 208.966 | ||
| ; | ||
|
|
||
| magnitudeWithLiteralVectors | ||
| required_capability: magnitude_scalar_vector_function | ||
|
|
||
| row a = 1 | ||
| | eval magnitude = round(v_magnitude([1, 2, 3]), 3) | ||
| | keep magnitude | ||
| ; | ||
|
|
||
| magnitude:double | ||
| 3.742 | ||
| ; | ||
|
|
||
| magnitudeWithStats | ||
| required_capability: magnitude_scalar_vector_function | ||
|
|
||
| from colors | ||
| | eval magnitude = round(v_magnitude(rgb_vector), 3) | ||
| | stats avg = round(avg(magnitude), 3), min = min(magnitude), max = max(magnitude) | ||
| ; | ||
|
|
||
| avg:double | min:double | max:double | ||
| 313.692 | 0.0 | 441.673 | ||
| ; | ||
svilen-mihaylov-elastic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,179 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.esql.expression.function.vector; | ||
|
|
||
| import org.apache.lucene.util.VectorUtil; | ||
| import org.elasticsearch.common.io.stream.NamedWriteableRegistry; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.compute.data.Block; | ||
| import org.elasticsearch.compute.data.FloatBlock; | ||
| import org.elasticsearch.compute.data.Page; | ||
| import org.elasticsearch.compute.operator.DriverContext; | ||
| import org.elasticsearch.compute.operator.EvalOperator; | ||
| import org.elasticsearch.xpack.esql.core.expression.Expression; | ||
| import org.elasticsearch.xpack.esql.core.expression.FoldContext; | ||
| import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; | ||
| import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; | ||
| import org.elasticsearch.xpack.esql.core.tree.NodeInfo; | ||
| import org.elasticsearch.xpack.esql.core.tree.Source; | ||
| import org.elasticsearch.xpack.esql.core.type.DataType; | ||
| import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; | ||
| import org.elasticsearch.xpack.esql.expression.function.Example; | ||
| import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; | ||
| import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; | ||
| import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; | ||
| import org.elasticsearch.xpack.esql.expression.function.Param; | ||
|
|
||
| import java.io.IOException; | ||
|
|
||
| import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; | ||
| import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; | ||
| import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; | ||
|
|
||
| public class Magnitude extends UnaryScalarFunction implements EvaluatorMapper, VectorFunction { | ||
|
|
||
| /** | ||
| * Stream registration for this function. It is registered under this class's own name; the | ||
| * previous value "Hamming" was a copy/paste leftover and would clash with an entry registered | ||
| * under that name for another function. | ||
| */ | ||
| public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Magnitude", Magnitude::new); | ||
| /** Scalar computation the evaluator applies to each non-null input vector. */ | ||
| static final ScalarEvaluatorFunction SCALAR_FUNCTION = Magnitude::calculateScalar; | ||
|
|
||
| /** | ||
| * Creates a magnitude function node over a single {@code dense_vector} argument. | ||
| * | ||
| * @param source source information forwarded to the superclass | ||
| * @param input expression producing the vector whose magnitude is computed | ||
| */ | ||
| @FunctionInfo( | ||
| returnType = "double", | ||
| preview = true, | ||
| description = "Calculates the magnitude of a dense_vector.", | ||
| examples = { @Example(file = "vector-magnitude", tag = "vector-magnitude") }, | ||
| appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) } | ||
| ) | ||
| public Magnitude( | ||
| Source source, | ||
| @Param(name = "input", type = { "dense_vector" }, description = "dense_vector for which to compute the magnitude") Expression input | ||
| ) { | ||
| super(source, input); | ||
| } | ||
|
|
||
| /** | ||
| * Stream constructor: all state is read by the superclass from {@code in}. | ||
| * Presumably this is the constructor that the {@code Magnitude::new} reference in ENTRY resolves to. | ||
| */ | ||
| private Magnitude(StreamInput in) throws IOException { | ||
| super(in); | ||
| } | ||
|
|
||
| /** Returns a new {@code Magnitude} with the same source and {@code newChild} as its argument. */ | ||
| @Override | ||
| protected UnaryScalarFunction replaceChild(Expression newChild) { | ||
| return new Magnitude(source(), newChild); | ||
| } | ||
|
|
||
| /** Builds the {@link NodeInfo} for this node from a constructor reference and the current argument. */ | ||
| @Override | ||
| protected NodeInfo<? extends Expression> info() { | ||
| return NodeInfo.create(this, Magnitude::new, field()); | ||
| } | ||
|
|
||
| /** Returns the name stored in {@code ENTRY}, so the two cannot drift apart. */ | ||
| @Override | ||
| public String getWriteableName() { | ||
| return ENTRY.name; | ||
| } | ||
|
|
||
| public static float calculateScalar(float[] scratch) { | ||
| return (float) Math.sqrt(VectorUtil.dotProduct(scratch, scratch)); | ||
svilen-mihaylov-elastic marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| /** The result type of this function: always {@link DataType#DOUBLE}. */ | ||
| @Override | ||
| public DataType dataType() { | ||
| return DataType.DOUBLE; | ||
| } | ||
|
|
||
| /** | ||
| * Validates the argument: children must be resolved, and the single argument must pass the | ||
| * not-null check and be of type {@code dense_vector}. | ||
| */ | ||
| @Override | ||
| protected TypeResolution resolveType() { | ||
| if (childrenResolved() == false) { | ||
| return new TypeResolution("Unresolved children"); | ||
| } | ||
|
|
||
| TypeResolution notNullCheck = isNotNull(field(), sourceText(), TypeResolutions.ParamOrdinal.FIRST); | ||
| TypeResolution vectorCheck = isType( | ||
| field(), | ||
| dt -> dt == DENSE_VECTOR, | ||
| sourceText(), | ||
| TypeResolutions.ParamOrdinal.FIRST, | ||
| "dense_vector" | ||
| ); | ||
| return notNullCheck.and(vectorCheck); | ||
| } | ||
|
|
||
| /** | ||
| * Functional interface for evaluating the scalar value of the underlying float array. | ||
| * Implemented in this class by {@code Magnitude::calculateScalar}. | ||
| */ | ||
| @FunctionalInterface | ||
| public interface ScalarEvaluatorFunction { | ||
| /** | ||
| * @param scratch the float components of one vector | ||
| * @return the scalar computed from {@code scratch} | ||
| */ | ||
| float calculateScalar(float[] scratch); | ||
| } | ||
|
|
||
| /** Delegates folding to the default {@link EvaluatorMapper} implementation, passing this node's source. */ | ||
| @Override | ||
| public Object fold(FoldContext ctx) { | ||
| Source origin = source(); | ||
| return EvaluatorMapper.super.fold(origin, ctx); | ||
| } | ||
|
|
||
| @Override | ||
| public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { | ||
| return new ScalarEvaluatorFactory(toEvaluator.apply(field()), SCALAR_FUNCTION, getClass().getSimpleName() + "Evaluator"); | ||
| } | ||
|
|
||
| /** | ||
| * Factory for the evaluator that turns every vector position of the child block into a single | ||
| * double computed by {@code scalarFunction}. Positions without values produce null. | ||
| */ | ||
| private record ScalarEvaluatorFactory( | ||
| EvalOperator.ExpressionEvaluator.Factory child, | ||
| ScalarEvaluatorFunction scalarFunction, | ||
| String evaluatorName | ||
| ) implements EvalOperator.ExpressionEvaluator.Factory { | ||
|
|
||
| @Override | ||
| public EvalOperator.ExpressionEvaluator get(DriverContext context) { | ||
| // Create the child evaluator once here, not once per page, so that close() can release it. | ||
| final EvalOperator.ExpressionEvaluator childEvaluator = child.get(context); | ||
| // TODO check whether to use this custom evaluator or reuse / define an existing one | ||
| return new EvalOperator.ExpressionEvaluator() { | ||
| @Override | ||
| public Block eval(Page page) { | ||
| try (FloatBlock block = (FloatBlock) childEvaluator.eval(page)) { | ||
| int positionCount = page.getPositionCount(); | ||
| int dimensions = 0; | ||
| // Get the first non-empty vector to calculate the dimension | ||
| for (int p = 0; p < positionCount; p++) { | ||
| if (block.getValueCount(p) != 0) { | ||
| dimensions = block.getValueCount(p); | ||
| break; | ||
| } | ||
| } | ||
| if (dimensions == 0) { | ||
| // No position holds a vector: return one null per input position so the result lines up with the page. | ||
| return context.blockFactory().newConstantNullBlock(positionCount); | ||
| } | ||
|
|
||
| float[] scratch = new float[dimensions]; | ||
| // Exactly one result is appended per position, so positionCount is the right size estimate. | ||
| try (var builder = context.blockFactory().newDoubleBlockBuilder(positionCount)) { | ||
| for (int p = 0; p < positionCount; p++) { | ||
| int dims = block.getValueCount(p); | ||
| if (dims == 0) { | ||
| // A null value for the vector yields a null result. | ||
| builder.appendNull(); | ||
| continue; | ||
| } | ||
| readFloatArray(block, block.getFirstValueIndex(p), dimensions, scratch); | ||
| float result = scalarFunction.calculateScalar(scratch); | ||
| builder.appendDouble(result); | ||
| } | ||
| return builder.build(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return evaluatorName() + "[child=" + child + "]"; | ||
| } | ||
|
|
||
| @Override | ||
| public void close() { | ||
| // Release the child evaluator created in get(). | ||
| childEvaluator.close(); | ||
| } | ||
| }; | ||
| } | ||
|
|
||
| /** Copies {@code dimensions} consecutive floats of {@code block}, starting at value index {@code position}, into {@code scratch}. */ | ||
| private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) { | ||
| for (int i = 0; i < dimensions; i++) { | ||
| scratch[i] = block.getFloat(position + i); | ||
| } | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| return evaluatorName() + "[child=" + child + "]"; | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.