Skip to content

Commit ed93d20

Browse files
committed
[fix][core] Resolve the compatibility issue between double and decimal types in the lead/tag function
1 parent 2fddb55 commit ed93d20

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,14 +1927,17 @@ protected LeadLagImplementor(boolean isLead) {
19271927
result.exitBlock();
19281928
BlockStatement thenBranch = thenBlock.toBlock();
19291929

1930-
//Expression defaultValue = rexArgs.size() == 3
1931-
// ? currentRowTranslator.translate(rexArgs.get(2), res.type)
1932-
// : getDefaultValue(res.type);
1933-
19341930
result.currentBlock().add(Expressions.declare(0, res, null));
1935-
result.currentBlock().add(
1936-
Expressions.ifThenElse(rowInRange, thenBranch,
1937-
Expressions.statement(Expressions.assign(res, defaultValue))));
1931+
if (("java.lang.Double".equalsIgnoreCase(res.getType().getTypeName()) || "double".equalsIgnoreCase(res.getType().getTypeName()))
1932+
&& "java.math.BigDecimal".equalsIgnoreCase(defaultValue.getType().getTypeName())) {
1933+
result.currentBlock().add(
1934+
Expressions.ifThenElse(rowInRange, thenBranch,
1935+
Expressions.statement(Expressions.assign(res, Expressions.call(defaultValue, BuiltInMethod.BIG_DECIMAL_DOUBLEVALUE.method)))));
1936+
} else {
1937+
result.currentBlock().add(
1938+
Expressions.ifThenElse(rowInRange, thenBranch,
1939+
Expressions.statement(Expressions.assign(res, defaultValue))));
1940+
}
19381941
return res;
19391942
}
19401943
}

core/src/main/java/org/apache/calcite/sql/fun/SqlLeadLagAggFunction.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,45 @@
1616
*/
1717
package org.apache.calcite.sql.fun;
1818

19+
import com.google.common.collect.ImmutableList;
1920
import org.apache.calcite.rel.type.RelDataType;
2021
import org.apache.calcite.sql.SqlAggFunction;
2122
import org.apache.calcite.sql.SqlFunctionCategory;
2223
import org.apache.calcite.sql.SqlKind;
2324
import org.apache.calcite.sql.SqlOperatorBinding;
2425
import org.apache.calcite.sql.type.OperandTypes;
2526
import org.apache.calcite.sql.type.ReturnTypes;
27+
import org.apache.calcite.sql.type.SameOperandTypeChecker;
2628
import org.apache.calcite.sql.type.SqlReturnTypeInference;
2729
import org.apache.calcite.sql.type.SqlSingleOperandTypeChecker;
30+
import org.apache.calcite.sql.type.SqlTypeFamily;
2831
import org.apache.calcite.sql.type.SqlTypeTransform;
2932
import org.apache.calcite.sql.type.SqlTypeTransforms;
3033
import org.apache.calcite.util.Optionality;
3134

3235
import com.google.common.base.Preconditions;
3336

37+
import java.util.List;
38+
3439
/**
3540
* <code>LEAD</code> and <code>LAG</code> aggregate functions
3641
* return the value of given expression evaluated at given offset.
3742
*/
3843
public class SqlLeadLagAggFunction extends SqlAggFunction {
3944
private static final SqlSingleOperandTypeChecker OPERAND_TYPES =
40-
OperandTypes.ANY
41-
.or(OperandTypes.ANY_NUMERIC)
42-
.or(OperandTypes.ANY_NUMERIC_ANY
43-
.and(OperandTypes.same(3, 0, 2)));
45+
OperandTypes.or(
46+
OperandTypes.ANY,
47+
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC),
48+
OperandTypes.and(
49+
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.NUMERIC,
50+
SqlTypeFamily.ANY),
51+
// Arguments 1 and 3 must have same type
52+
new SameOperandTypeChecker(3) {
53+
@Override protected List<Integer>
54+
getOperandList(int operandCount) {
55+
return ImmutableList.of(0, 2);
56+
}
57+
}));
4458

4559
private static final SqlReturnTypeInference RETURN_TYPE =
4660
ReturnTypes.ARG0.andThen(SqlLeadLagAggFunction::transformType);

core/src/main/java/org/apache/calcite/util/BuiltInMethod.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ public enum BuiltInMethod {
678678
long.class),
679679
BIG_DECIMAL_ADD(BigDecimal.class, "add", BigDecimal.class),
680680
BIG_DECIMAL_NEGATE(BigDecimal.class, "negate"),
681+
BIG_DECIMAL_DOUBLEVALUE(BigDecimal.class, "doubleValue"),
681682
COMPARE_TO(Comparable.class, "compareTo", Object.class);
682683

683684
@SuppressWarnings("ImmutableEnumChecker")

0 commit comments

Comments
 (0)