Skip to content

Commit 6cc2115

Browse files
committed
Rename to v_cosine, add analyzer tests for implicit casting
1 parent 85e2426 commit 6cc2115

File tree

5 files changed

+32
-6
lines changed

5 files changed

+32
-6
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
3535
public static Iterable<Object[]> parameters() throws Exception {
3636
List<Object[]> params = new ArrayList<>();
3737

38-
params.add(new Object[] { "v_cosine_similarity", VectorSimilarityFunction.COSINE });
38+
params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
3939

4040
return params;
4141
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ private static FunctionDefinition[][] snapshotFunctions() {
489489
def(StGeohex.class, StGeohex::new, "st_geohex"),
490490
def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
491491
def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
492-
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine_similarity") } };
492+
def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } };
493493
}
494494

495495
public EsqlFunctionRegistry snapshotRegistry() {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
/**
3333
* Base class for vector similarity functions, which compute a similarity score between two dense vectors
3434
*/
35-
abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction {
35+
public abstract class VectorSimilarityFunction extends EsqlScalarFunction implements VectorFunction {
3636

3737
private final Expression left;
3838
private final Expression right;

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
5858
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
5959
import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
60+
import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
6061
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
6162
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
6263
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
@@ -92,6 +93,7 @@
9293
import java.time.Period;
9394
import java.util.ArrayList;
9495
import java.util.List;
96+
import java.util.Locale;
9597
import java.util.Map;
9698
import java.util.Set;
9799
import java.util.function.Function;
@@ -123,6 +125,7 @@
123125
import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
124126
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
125127
import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
128+
import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
126129
import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
127130
import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
128131
import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
@@ -2370,7 +2373,7 @@ public void testImplicitCasting() {
23702373
assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
23712374
}
23722375

2373-
public void testDenseVectorImplicitCasting() {
2376+
public void testDenseVectorImplicitCastingKnn() {
23742377
assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
23752378
Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
23762379

@@ -2387,6 +2390,29 @@ public void testDenseVectorImplicitCasting() {
23872390
assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
23882391
}
23892392

2393+
public void testDenseVectorImplicitCastingSimilarityFunctions() {
2394+
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2395+
checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])");
2396+
}
2397+
}
2398+
2399+
private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction) {
2400+
var plan = analyze(String.format(Locale.ROOT, """
2401+
from test | eval similarity = %s
2402+
""", similarityFunction), "mapping-dense_vector.json");
2403+
2404+
var limit = as(plan, Limit.class);
2405+
var eval = as(limit.child(), Eval.class);
2406+
var alias = as(eval.fields().get(0), Alias.class);
2407+
assertEquals("similarity", alias.name());
2408+
var similarity = as(alias.child(), VectorSimilarityFunction.class);
2409+
var left = as(similarity.left(), FieldAttribute.class);
2410+
assertEquals("vector", left.name());
2411+
var right = as(similarity.right(), Literal.class);
2412+
assertThat(right.dataType(), is(DENSE_VECTOR));
2413+
assertThat(right.value(), equalTo(List.of(0.342, 0.164, 0.234)));;
2414+
}
2415+
23902416
public void testRateRequiresCounterTypes() {
23912417
assumeTrue("rate requires snapshot builds", Build.current().isSnapshot());
23922418
Analyzer analyzer = analyzer(tsdbIndexResolution());

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,8 +2238,8 @@ private void checkFullTextFunctionsInStats(String functionInvocation) {
22382238

22392239
public void testVectorSimilarityFunctionsNullArgs() throws Exception {
22402240
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
2241-
checkVectorSimilarityFunctionsNullArgs("v_cosine_similarity(null, vector)", "first");
2242-
checkVectorSimilarityFunctionsNullArgs("v_cosine_similarity(vector, null)", "second");
2241+
checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first");
2242+
checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second");
22432243
}
22442244
}
22452245

0 commit comments

Comments
 (0)