Skip to content

Commit 982676d

Browse files
Raaghav0facebook-github-bot
authored and committed
Check for nulls in array functions l2_squared and cosine_similarity
Check for nulls in the array functions l2_squared and cosine_similarity; handle nulls in arrays by throwing an error instead; update the docs; reduce null-check overhead.
1 parent af29aa6 commit 982676d

File tree

3 files changed

+35
-6
lines changed

3 files changed

+35
-6
lines changed

presto-docs/src/main/sphinx/functions/math.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,22 @@ Mathematical Functions
4242

4343
.. function:: cosine_similarity(x, y) -> double
4444

45-
Returns the cosine similarity between the arrays ``x`` and ``y``::
45+
Returns the cosine similarity between the arrays ``x`` and ``y``.
46+
If the input arrays have different sizes or if the input arrays contain a null, the function throws a user error::
4647

4748
SELECT cosine_similarity(ARRAY[1.2], ARRAY[2.0]); -- 1.0
4849

4950
.. function:: l2_squared(array(real), array(real)) -> real
5051

5152
Returns the squared `Euclidean distance <https://en.wikipedia.org/wiki/Euclidean_distance>`_ between the vectors represented as array(real).
52-
If the input arrays have different sizes, the function throws user error::
53+
If the input arrays have different sizes or if the input arrays contain a null, the function throws a user error::
5354

5455
SELECT l2_squared(ARRAY[1.0], ARRAY[2.0]); -- 1.0
5556

5657
.. function:: l2_squared(array(double), array(double)) -> double
5758

5859
Returns the squared `Euclidean distance <https://en.wikipedia.org/wiki/Euclidean_distance>`_ between the vectors represented as array(double).
59-
If the input arrays have different sizes, the function throws user error::
60+
If the input arrays have different sizes or if the input arrays contain a null, the function throws a user error::
6061

6162
SELECT l2_squared(ARRAY[1.0], ARRAY[2.0]); -- 1.0
6263

presto-main-base/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,6 +1632,11 @@ public static Double arrayCosineSimilarity(@SqlType("array(double)") Block leftA
16321632
INVALID_FUNCTION_ARGUMENT,
16331633
"Both array arguments need to have identical size");
16341634

1635+
checkCondition(
1636+
!(leftArray.mayHaveNull() || rightArray.mayHaveNull()),
1637+
INVALID_FUNCTION_ARGUMENT,
1638+
"Both arrays must not have nulls");
1639+
16351640
Double normLeftArray = array2Norm(leftArray);
16361641
Double normRightArray = array2Norm(rightArray);
16371642

@@ -1654,6 +1659,11 @@ public static long arrayL2Squared(@SqlType("array(real)") Block leftArray, @SqlT
16541659
INVALID_FUNCTION_ARGUMENT,
16551660
"Both array arguments need to have identical size");
16561661

1662+
checkCondition(
1663+
!(leftArray.mayHaveNull() || rightArray.mayHaveNull()),
1664+
INVALID_FUNCTION_ARGUMENT,
1665+
"Both arrays must not have nulls");
1666+
16571667
float sum = 0.0f;
16581668
for (int i = 0; i < leftArray.getPositionCount(); i++) {
16591669
float left = intBitsToFloat((int) leftArray.getInt(i));
@@ -1676,6 +1686,12 @@ public static double arrayL2SquaredDouble(
16761686
leftArray.getPositionCount() == rightArray.getPositionCount(),
16771687
INVALID_FUNCTION_ARGUMENT,
16781688
"Both array arguments need to have identical size");
1689+
1690+
checkCondition(
1691+
!(leftArray.mayHaveNull() || rightArray.mayHaveNull()),
1692+
INVALID_FUNCTION_ARGUMENT,
1693+
"Both arrays must not have nulls");
1694+
16791695
double sum = 0.0;
16801696
for (int i = 0; i < leftArray.getPositionCount(); i++) {
16811697
double left = DOUBLE.getDouble(leftArray, i);

presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestMathFunctions.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,9 +1392,7 @@ public void testArrayCosineSimilarity()
13921392
DOUBLE,
13931393
null);
13941394

1395-
assertFunction("cosine_similarity(array [1.0E0, null], array [1.0E0, 3.0E0])",
1396-
DOUBLE,
1397-
null);
1395+
assertInvalidFunction("cosine_similarity(array [1.0E0, null], array [1.0E0, 3.0E0])", "Both arrays must not have nulls");
13981396

13991397
assertInvalidFunction("cosine_similarity(array [], array [1.0E0, 3.0E0])", "Both array arguments need to have identical size");
14001398

@@ -1405,6 +1403,12 @@ public void testArrayCosineSimilarity()
14051403
assertFunction("cosine_similarity(array [], null)",
14061404
DOUBLE,
14071405
null);
1406+
1407+
assertInvalidFunction(
1408+
"cosine_similarity(array[1.0, null, 3.0], array[1.0, 2.0, 3.0])", "Both arrays must not have nulls");
1409+
1410+
assertInvalidFunction(
1411+
"cosine_similarity(array[1.0, 2.0, 3.0], array[1.0, null, 3.0])", "Both arrays must not have nulls");
14081412
}
14091413

14101414
@Test
@@ -1434,6 +1438,10 @@ public void testArrayL2Squared()
14341438
assertFunction(
14351439
"l2_squared(CAST(null AS array(real)), CAST(null AS array(real)))",
14361440
REAL, null);
1441+
assertInvalidFunction(
1442+
"l2_squared(array[REAL '1.0', null, REAL '3.0'], array[REAL '1.0', REAL '2.0', REAL '3.0'])", "Both arrays must not have nulls");
1443+
assertInvalidFunction(
1444+
"l2_squared(array[REAL '1.0', REAL '2.0', REAL '3.0'], array[REAL '1.0', null, REAL '3.0'])", "Both arrays must not have nulls");
14371445
}
14381446

14391447
@Test
@@ -1463,6 +1471,10 @@ public void testArrayL2SquaredDouble()
14631471
assertFunction(
14641472
"l2_squared(CAST(null AS array(double)), CAST(null AS array(double)))",
14651473
DOUBLE, null);
1474+
assertInvalidFunction(
1475+
"l2_squared(array[DOUBLE '1.0', null, DOUBLE '3.0'], array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'])", "Both arrays must not have nulls");
1476+
assertInvalidFunction(
1477+
"l2_squared(array[DOUBLE '1.0', DOUBLE '2.0', DOUBLE '3.0'], array[DOUBLE '1.0', null, DOUBLE '3.0'])", "Both arrays must not have nulls");
14661478
}
14671479

14681480
@Test

0 commit comments

Comments
 (0)