Skip to content

Commit 7ce3b87

Browse files
committed
Coerce return types for DIVIDE and MOD UDFs
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent acb5a8b commit 7ce3b87

File tree

7 files changed

+177
-94
lines changed

7 files changed

+177
-94
lines changed

core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,20 @@
55

66
package org.opensearch.sql.calcite.udf.datetimeUDF;
77

8+
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE;
9+
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME;
10+
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP;
811
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.restoreFunctionProperties;
912
import static org.opensearch.sql.calcite.utils.datetime.DateTimeApplyUtils.transferInputToExprValue;
1013
import static org.opensearch.sql.expression.datetime.DateTimeFunctions.exprAddTime;
1114
import static org.opensearch.sql.expression.datetime.DateTimeFunctions.exprSubTime;
1215

16+
import org.apache.calcite.rel.type.RelDataType;
17+
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1318
import org.apache.calcite.sql.type.SqlTypeName;
19+
import org.opensearch.sql.calcite.type.ExprSqlType;
1420
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
21+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
1522
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
1623
import org.opensearch.sql.data.model.ExprTimeValue;
1724
import org.opensearch.sql.data.model.ExprValue;
@@ -44,4 +51,35 @@ public Object eval(Object... args) {
4451
return result.valueForCalcite();
4552
}
4653
}
54+
55+
/**
56+
* ADDTIME and SUBTIME has special return type maps: (DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) ->
57+
* TIMESTAMP (TIME, DATE/TIMESTAMP/TIME) -> TIME Therefore, we create a special return type
58+
* inference for them.
59+
*/
60+
public static SqlReturnTypeInference getReturnTypeForTimeAddSub() {
61+
return opBinding -> {
62+
RelDataType operandType0 = opBinding.getOperandType(0);
63+
if (operandType0 instanceof ExprSqlType) {
64+
OpenSearchTypeFactory.ExprUDT exprUDT = ((ExprSqlType) operandType0).getUdt();
65+
if (exprUDT == EXPR_DATE || exprUDT == EXPR_TIMESTAMP) {
66+
return UserDefinedFunctionUtils.nullableTimestampUDT;
67+
} else if (exprUDT == EXPR_TIME) {
68+
return UserDefinedFunctionUtils.nullableTimeUDT;
69+
} else {
70+
throw new IllegalArgumentException("Unsupported UDT type");
71+
}
72+
}
73+
SqlTypeName typeName = operandType0.getSqlTypeName();
74+
return switch (typeName) {
75+
case DATE, TIMESTAMP ->
76+
// Return TIMESTAMP
77+
UserDefinedFunctionUtils.nullableTimestampUDT;
78+
case TIME ->
79+
// Return TIME
80+
UserDefinedFunctionUtils.nullableTimeUDT;
81+
default -> throw new IllegalArgumentException("Unsupported type: " + typeName);
82+
};
83+
};
84+
}
4785
}

core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,28 @@
66
package org.opensearch.sql.calcite.udf.mathUDF;
77

88
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
9+
import org.opensearch.sql.calcite.utils.MathUtils;
10+
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
911

1012
public class DivideFunction implements UserDefinedFunction {
1113

1214
@Override
1315
public Object eval(Object... args) {
14-
double dividend = ((Number) args[0]).doubleValue();
15-
double divisor = ((Number) args[1]).doubleValue();
16-
return dividend / divisor;
16+
if (UserDefinedFunctionUtils.containsNull(args)) {
17+
return null;
18+
}
19+
20+
Number dividend = (Number) args[0];
21+
Number divisor = (Number) args[1];
22+
23+
if (Math.abs(divisor.doubleValue()) < MathUtils.EPSILON) {
24+
return null;
25+
}
26+
27+
double result = dividend.doubleValue() / divisor.doubleValue();
28+
if (MathUtils.isIntegral(dividend) && MathUtils.isIntegral(divisor)) {
29+
return MathUtils.coerceToWidestIntegralType(dividend, divisor, (long) result);
30+
}
31+
return MathUtils.coerceToWidestFloatingType(dividend, divisor, result);
1732
}
1833
}

core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import java.math.BigDecimal;
99
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
10+
import org.opensearch.sql.calcite.utils.MathUtils;
1011

1112
/**
1213
* Calculate the remainder of x divided by y<br>
@@ -32,33 +33,22 @@ public Object eval(Object... args) {
3233
arg0.getClass().getSimpleName(), arg1.getClass().getSimpleName()));
3334
}
3435

35-
// TODO: This precision check is arbitrary.
36-
if (Math.abs(num1.doubleValue()) < 0.0000001) {
36+
if (Math.abs(num1.doubleValue()) < MathUtils.EPSILON) {
3737
return null;
3838
}
3939

40-
if (isIntegral(num0) && isIntegral(num1)) {
40+
if (MathUtils.isIntegral(num0) && MathUtils.isIntegral(num1)) {
4141
long l0 = num0.longValue();
4242
long l1 = num1.longValue();
4343
// It returns negative values when l0 is negative
4444
long result = l0 % l1;
4545
// Return the wider type between l0 and l1
46-
if (num0 instanceof Long || num1 instanceof Long) {
47-
return result;
48-
}
49-
return (int) result;
46+
return MathUtils.coerceToWidestIntegralType(num0, num1, result);
5047
}
5148

5249
BigDecimal b0 = new BigDecimal(num0.toString());
5350
BigDecimal b1 = new BigDecimal(num1.toString());
5451
BigDecimal result = b0.remainder(b1);
55-
if (num0 instanceof Double || num1 instanceof Double) {
56-
return result.doubleValue();
57-
}
58-
return result.floatValue();
59-
}
60-
61-
private boolean isIntegral(Number n) {
62-
return n instanceof Byte || n instanceof Short || n instanceof Integer || n instanceof Long;
52+
return MathUtils.coerceToWidestFloatingType(num0, num1, result.doubleValue());
6353
}
6454
}

core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.sql.calcite.utils;
77

88
import static java.lang.Math.E;
9-
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.*;
109
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName;
1110
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.*;
1211

@@ -139,7 +138,8 @@ static SqlOperator translate(String op) {
139138
case "*":
140139
return SqlStdOperatorTable.MULTIPLY;
141140
case "/":
142-
return TransferUserDefinedFunction(DivideFunction.class, "/", ReturnTypes.DOUBLE);
141+
return TransferUserDefinedFunction(
142+
DivideFunction.class, "/", ReturnTypes.QUOTIENT_NULLABLE);
143143
// Built-in String Functions
144144
case "ASCII":
145145
return SqlStdOperatorTable.ASCII;
@@ -217,8 +217,7 @@ static SqlOperator translate(String op) {
217217
// The MOD function in PPL supports floating-point parameters, e.g., MOD(5.5, 2) = 1.5,
218218
// MOD(3.1, 2.1) = 1.1,
219219
// whereas SqlStdOperatorTable.MOD supports only integer / long parameters.
220-
return TransferUserDefinedFunction(
221-
ModFunction.class, "MOD", getLeastRestrictiveReturnTypeAmongArgsAt(List.of(0, 1)));
220+
return TransferUserDefinedFunction(ModFunction.class, "MOD", ReturnTypes.LEAST_RESTRICTIVE);
222221
case "PI":
223222
return SqlStdOperatorTable.PI;
224223
case "POW", "POWER":
@@ -266,9 +265,7 @@ static SqlOperator translate(String op) {
266265
DateAddSubFunction.class, "DATE_SUB", timestampInference);
267266
case "ADDTIME", "SUBTIME":
268267
return TransferUserDefinedFunction(
269-
TimeAddSubFunction.class,
270-
capitalOP,
271-
UserDefinedFunctionUtils.getReturnTypeForTimeAddSub());
268+
TimeAddSubFunction.class, capitalOP, TimeAddSubFunction.getReturnTypeForTimeAddSub());
272269
case "DAY_OF_WEEK", "DAYOFWEEK":
273270
return TransferUserDefinedFunction(
274271
DayOfWeekFunction.class, capitalOP, INTEGER_FORCE_NULLABLE);
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.utils;
7+
8+
public class MathUtils {
9+
public static final double EPSILON = 0.0000001;
10+
11+
public static boolean isIntegral(Number n) {
12+
return n instanceof Byte || n instanceof Short || n instanceof Integer || n instanceof Long;
13+
}
14+
15+
/**
16+
* Converts a long value to the least restrictive integral type based on the types of two input
17+
* numbers.
18+
*
19+
* <p>This is useful when performing operations like division or modulo and you want to preserve
20+
* the most appropriate type (e.g., int vs long).
21+
*
22+
* @param a one operand involved in the operation
23+
* @param b another operand involved in the operation
24+
* @param value the result to convert to the least restrictive integral type
25+
* @return the value converted to Byte, Short, Integer, or Long
26+
*/
27+
public static Number coerceToWidestIntegralType(Number a, Number b, long value) {
28+
if (a instanceof Long || b instanceof Long) {
29+
return value;
30+
} else if (a instanceof Integer || b instanceof Integer) {
31+
return (int) value;
32+
} else if (a instanceof Short || b instanceof Short) {
33+
return (short) value;
34+
} else {
35+
return (byte) value;
36+
}
37+
}
38+
39+
/**
40+
* Converts a double value to the least restrictive floating type based on the types of two input
41+
*
42+
* @param a one operand involved in the operation
43+
* @param b another operand involved in the operation
44+
* @param value the result to convert to the least restrictive floating type
45+
* @return the value converted to Float or Double
46+
*/
47+
public static Number coerceToWidestFloatingType(Number a, Number b, double value) {
48+
if (a instanceof Double || b instanceof Double) {
49+
return value;
50+
} else {
51+
return (float) value;
52+
}
53+
}
54+
}

core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -104,37 +104,6 @@ public static SqlOperator TransferUserDefinedFunction(
104104
udfFunction);
105105
}
106106

107-
/**
108-
* Infer return argument type as the widest return type among arguments as specified positions.
109-
* E.g. (Integer, Long) -> Long; (Double, Float, SHORT) -> Double
110-
*
111-
* @param positions positions where the return type should be inferred from
112-
* @param nullable whether the returned value is nullable
113-
* @return The type inference
114-
*/
115-
public static SqlReturnTypeInference getLeastRestrictiveReturnTypeAmongArgsAt(
116-
List<Integer> positions) {
117-
return opBinding -> {
118-
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
119-
List<RelDataType> types = new ArrayList<>();
120-
121-
for (int position : positions) {
122-
if (position < 0 || position >= opBinding.getOperandCount()) {
123-
throw new IllegalArgumentException("Invalid argument position: " + position);
124-
}
125-
types.add(opBinding.getOperandType(position));
126-
}
127-
128-
RelDataType widerType = typeFactory.leastRestrictive(types);
129-
if (widerType == null) {
130-
throw new IllegalArgumentException(
131-
"Cannot determine a common type for the given positions.");
132-
}
133-
134-
return widerType;
135-
};
136-
}
137-
138107
static SqlReturnTypeInference getReturnTypeInferenceForArray() {
139108
return opBinding -> {
140109
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
@@ -150,37 +119,6 @@ static SqlReturnTypeInference getReturnTypeInferenceForArray() {
150119
};
151120
}
152121

153-
/**
154-
* ADDTIME and SUBTIME has special return type maps: (DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) ->
155-
* TIMESTAMP (TIME, DATE/TIMESTAMP/TIME) -> TIME Therefore, we create a special return type
156-
* inference for them.
157-
*/
158-
static SqlReturnTypeInference getReturnTypeForTimeAddSub() {
159-
return opBinding -> {
160-
RelDataType operandType0 = opBinding.getOperandType(0);
161-
if (operandType0 instanceof ExprSqlType) {
162-
ExprUDT exprUDT = ((ExprSqlType) operandType0).getUdt();
163-
if (exprUDT == EXPR_DATE || exprUDT == EXPR_TIMESTAMP) {
164-
return nullableTimestampUDT;
165-
} else if (exprUDT == EXPR_TIME) {
166-
return nullableTimeUDT;
167-
} else {
168-
throw new IllegalArgumentException("Unsupported UDT type");
169-
}
170-
}
171-
SqlTypeName typeName = operandType0.getSqlTypeName();
172-
return switch (typeName) {
173-
case DATE, TIMESTAMP ->
174-
// Return TIMESTAMP
175-
nullableTimestampUDT;
176-
case TIME ->
177-
// Return TIME
178-
nullableTimeUDT;
179-
default -> throw new IllegalArgumentException("Unsupported type: " + typeName);
180-
};
181-
};
182-
}
183-
184122
static List<Integer> transferStringExprToDateValue(String timeExpr) {
185123
try {
186124
if (timeExpr.contains(":")) {

0 commit comments

Comments
 (0)