Skip to content

Commit 191b1a3

Browse files
Copilot and zclllyybb committed
Add FE constant folding for cosine_similarity and use test{sql,exception} pattern
Co-authored-by: zclllyybb <61408379+zclllyybb@users.noreply.github.com>
1 parent ee40779 commit 191b1a3

File tree

3 files changed

+108
-27
lines changed

3 files changed

+108
-27
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.apache.doris.nereids.exceptions.NotSupportedException;
2121
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
22+
import org.apache.doris.nereids.trees.expressions.functions.executable.ArrayArithmetic;
2223
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeAcquire;
2324
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeArithmetic;
2425
import org.apache.doris.nereids.trees.expressions.functions.executable.DateTimeExtractAndTransform;
@@ -177,6 +178,7 @@ private void registerFunctions() {
177178
}
178179
ImmutableMultimap.Builder<String, Method> mapBuilder = new ImmutableMultimap.Builder<>();
179180
List<Class<?>> classes = ImmutableList.of(
181+
ArrayArithmetic.class,
180182
DateTimeAcquire.class,
181183
DateTimeExtractAndTransform.class,
182184
DateLiteral.class,
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.executable;
19+
20+
import org.apache.doris.nereids.exceptions.AnalysisException;
21+
import org.apache.doris.nereids.trees.expressions.ExecFunction;
22+
import org.apache.doris.nereids.trees.expressions.Expression;
23+
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
24+
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
25+
import org.apache.doris.nereids.trees.expressions.literal.Literal;
26+
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
27+
28+
import java.util.List;
29+
30+
/**
31+
* Executable functions for array operations.
32+
*/
33+
public class ArrayArithmetic {
34+
35+
/**
36+
* Compute cosine similarity between two float arrays.
37+
* cosine_similarity(x, y) = dot(x, y) / (||x|| * ||y||)
38+
*/
39+
@ExecFunction(name = "cosine_similarity")
40+
public static Expression cosineSimilarity(ArrayLiteral array1, ArrayLiteral array2) {
41+
List<Literal> items1 = array1.getValue();
42+
List<Literal> items2 = array2.getValue();
43+
44+
// Check for null elements
45+
for (Literal item : items1) {
46+
if (item instanceof NullLiteral) {
47+
throw new AnalysisException("function cosine_similarity cannot have null");
48+
}
49+
}
50+
for (Literal item : items2) {
51+
if (item instanceof NullLiteral) {
52+
throw new AnalysisException("function cosine_similarity cannot have null");
53+
}
54+
}
55+
56+
// Check array sizes
57+
if (items1.size() != items2.size()) {
58+
throw new AnalysisException("function cosine_similarity have different input element sizes of array: "
59+
+ items1.size() + " and " + items2.size());
60+
}
61+
62+
// Handle empty arrays
63+
if (items1.isEmpty()) {
64+
return new FloatLiteral(0.0f);
65+
}
66+
67+
// Compute dot product and squared norms
68+
double dotProd = 0.0;
69+
double squaredX = 0.0;
70+
double squaredY = 0.0;
71+
72+
for (int i = 0; i < items1.size(); i++) {
73+
double x = ((Number) items1.get(i).getValue()).doubleValue();
74+
double y = ((Number) items2.get(i).getValue()).doubleValue();
75+
dotProd += x * y;
76+
squaredX += x * x;
77+
squaredY += y * y;
78+
}
79+
80+
// Handle zero vectors
81+
if (squaredX == 0.0 || squaredY == 0.0) {
82+
return new FloatLiteral(0.0f);
83+
}
84+
85+
float result = (float) (dotProd / Math.sqrt(squaredX * squaredY));
86+
return new FloatLiteral(result);
87+
}
88+
}

regression-test/suites/query_p0/sql_functions/array_functions/test_array_distance_functions.groovy

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,37 +51,32 @@ suite("test_array_distance_functions") {
5151
}
5252

5353
// abnormal test cases
54-
try {
54+
test {
5555
sql "SELECT l2_distance([0, 0], [1])"
56-
} catch (Exception ex) {
57-
assert("${ex}".contains("function l2_distance have different input element sizes"))
56+
exception "function l2_distance have different input element sizes"
5857
}
5958

60-
try {
59+
test {
6160
sql "SELECT cosine_distance([NULL], [NULL, NULL])"
62-
} catch (Exception ex) {
63-
assert("${ex}".contains("function cosine_distance cannot have null"))
61+
exception "function cosine_distance cannot have null"
6462
}
6563

6664
// Test cases for the nullable array offset fix
6765
// These cases specifically test scenarios where absolute offsets might differ
6866
// but actual array sizes are the same (should pass) or different (should fail)
69-
try {
67+
test {
7068
sql "SELECT l1_distance([1.0, 2.0, 3.0], [4.0, 5.0])"
71-
} catch (Exception ex) {
72-
assert("${ex}".contains("function l1_distance have different input element sizes"))
69+
exception "function l1_distance have different input element sizes"
7370
}
7471

75-
try {
72+
test {
7673
sql "SELECT inner_product([1.0], [2.0, 3.0, 4.0])"
77-
} catch (Exception ex) {
78-
assert("${ex}".contains("function inner_product have different input element sizes"))
74+
exception "function inner_product have different input element sizes"
7975
}
8076

81-
try {
77+
test {
8278
sql "SELECT l1_distance([1, 2, 3], [0, NULL, 0])"
83-
} catch (Exception ex) {
84-
assert("${ex}".contains("function l1_distance cannot have null"))
79+
exception "function l1_distance cannot have null"
8580
}
8681

8782
// Edge case: empty arrays should work
@@ -149,29 +144,25 @@ suite("test_array_distance_functions") {
149144
}
150145

151146
// Test array with NULL element: should throw exception
152-
try {
147+
test {
153148
sql "SELECT cosine_similarity([1, NULL, 3], [4, 5, 6])"
154-
} catch (Exception ex) {
155-
assert("${ex}".contains("function cosine_similarity cannot have null"))
149+
exception "function cosine_similarity cannot have null"
156150
}
157151

158-
try {
152+
test {
159153
sql "SELECT cosine_similarity([1, 2, 3], [4, NULL, 6])"
160-
} catch (Exception ex) {
161-
assert("${ex}".contains("function cosine_similarity cannot have null"))
154+
exception "function cosine_similarity cannot have null"
162155
}
163156

164157
// Test different array sizes: should throw exception
165-
try {
158+
test {
166159
sql "SELECT cosine_similarity([1, 2], [1, 2, 3])"
167-
} catch (Exception ex) {
168-
assert("${ex}".contains("function cosine_similarity have different input element sizes"))
160+
exception "function cosine_similarity have different input element sizes"
169161
}
170162

171-
try {
163+
test {
172164
sql "SELECT cosine_similarity([1, 2, 3, 4], [1, 2])"
173-
} catch (Exception ex) {
174-
assert("${ex}".contains("function cosine_similarity have different input element sizes"))
165+
exception "function cosine_similarity have different input element sizes"
175166
}
176167

177168
// Test large values

0 commit comments

Comments
 (0)