Skip to content

Commit 5a8f92d

Browse files
committed
WIP: DECIMAL issues fixed
1 parent 8b6c2ce commit 5a8f92d

File tree

5 files changed

+375
-23
lines changed

5 files changed

+375
-23
lines changed

exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillReduceAggregatesRule.java

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,25 @@ public void onMatch(RelOptRuleCall ruleCall) {
124124
*/
125125
private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
126126
for (AggregateCall call : aggCallList) {
127+
// Check the aggregate function name directly
128+
String aggName = call.getAggregation().getName();
129+
if (aggName.equalsIgnoreCase("AVG") ||
130+
aggName.equalsIgnoreCase("STDDEV_POP") || aggName.equalsIgnoreCase("STDDEV_SAMP") ||
131+
aggName.equalsIgnoreCase("VAR_POP") || aggName.equalsIgnoreCase("VAR_SAMP") ||
132+
aggName.equalsIgnoreCase("SUM") || aggName.equalsIgnoreCase("SUM0") ||
133+
aggName.equalsIgnoreCase("$SUM0")) {
134+
return true;
135+
}
136+
137+
// Fallback: check by SqlKind and instanceof for standard Calcite functions
127138
SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(call.getAggregation());
139+
SqlKind kind = sqlAggFunction.getKind();
140+
if (kind == SqlKind.AVG ||
141+
kind == SqlKind.STDDEV_POP || kind == SqlKind.STDDEV_SAMP ||
142+
kind == SqlKind.VAR_POP || kind == SqlKind.VAR_SAMP ||
143+
kind == SqlKind.SUM || kind == SqlKind.SUM0) {
144+
return true;
145+
}
128146
if (sqlAggFunction instanceof SqlAvgAggFunction
129147
|| sqlAggFunction instanceof SqlSumAggFunction) {
130148
return true;
@@ -229,16 +247,48 @@ private RexNode reduceAgg(
229247
Map<AggregateCall, RexNode> aggCallMapping,
230248
List<RexNode> inputExprs) {
231249
final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());
232-
if (sqlAggFunction instanceof SqlSumAggFunction) {
250+
final SqlKind sqlKind = sqlAggFunction.getKind();
251+
252+
// Handle SUM
253+
if (sqlKind == SqlKind.SUM || sqlKind == SqlKind.SUM0 ||
254+
sqlAggFunction instanceof SqlSumAggFunction) {
233255
// replace original SUM(x) with
234256
// case COUNT(x) when 0 then null else SUM0(x) end
235257
return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
236258
}
237-
if (sqlAggFunction instanceof SqlAvgAggFunction) {
238-
// for DECIMAL data types does not produce rewriting of complex calls,
239-
// since SUM returns value with 38 precision and further handling of the value
240-
// causes the loss of the scale
241-
if (oldCall.getType().getSqlTypeName() == SqlTypeName.DECIMAL) {
259+
260+
// Handle AVG, VAR_*, STDDEV_* - check by SqlKind or by name for Drill-wrapped functions
261+
String aggName = oldCall.getAggregation().getName();
262+
boolean isVarianceOrAvg = (sqlKind == SqlKind.AVG || sqlKind == SqlKind.STDDEV_POP || sqlKind == SqlKind.STDDEV_SAMP ||
263+
sqlKind == SqlKind.VAR_POP || sqlKind == SqlKind.VAR_SAMP ||
264+
sqlAggFunction instanceof SqlAvgAggFunction ||
265+
aggName.equalsIgnoreCase("AVG") || aggName.equalsIgnoreCase("VAR_POP") ||
266+
aggName.equalsIgnoreCase("VAR_SAMP") || aggName.equalsIgnoreCase("STDDEV_POP") ||
267+
aggName.equalsIgnoreCase("STDDEV_SAMP"));
268+
if (isVarianceOrAvg) {
269+
270+
// Determine the subtype from name if SqlKind is OTHER_FUNCTION (Drill-wrapped)
271+
SqlKind subtype = sqlKind;
272+
if (sqlKind == SqlKind.OTHER_FUNCTION || sqlKind == SqlKind.OTHER) {
273+
// Use aggName already declared above
274+
if (aggName.equalsIgnoreCase("AVG")) {
275+
subtype = SqlKind.AVG;
276+
} else if (aggName.equalsIgnoreCase("VAR_POP")) {
277+
subtype = SqlKind.VAR_POP;
278+
} else if (aggName.equalsIgnoreCase("VAR_SAMP")) {
279+
subtype = SqlKind.VAR_SAMP;
280+
} else if (aggName.equalsIgnoreCase("STDDEV_POP")) {
281+
subtype = SqlKind.STDDEV_POP;
282+
} else if (aggName.equalsIgnoreCase("STDDEV_SAMP")) {
283+
subtype = SqlKind.STDDEV_SAMP;
284+
}
285+
}
286+
287+
// For DECIMAL data types, only skip reduction for AVG (not for VAR_*/STDDEV_*)
288+
// AVG reduction causes loss of scale, but variance/stddev MUST be reduced
289+
// to avoid Calcite 1.38 CALCITE-6427 bug that creates invalid DECIMAL types
290+
if (oldCall.getType().getSqlTypeName() == SqlTypeName.DECIMAL &&
291+
subtype == SqlKind.AVG) {
242292
return oldAggRel.getCluster().getRexBuilder().addAggCall(
243293
oldCall,
244294
oldAggRel.getGroupCount(),
@@ -248,7 +298,6 @@ private RexNode reduceAgg(
248298
oldAggRel.getInput(),
249299
oldCall.getArgList().get(0))));
250300
}
251-
final SqlKind subtype = sqlAggFunction.getKind();
252301
switch (subtype) {
253302
case AVG:
254303
// replace original AVG(x) with SUM(x) / COUNT(x)
@@ -526,16 +575,29 @@ private RexNode reduceStddev(
526575
RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
527576
inputExprs.set(argOrdinal, argRef);
528577

529-
final RexNode argSquared =
578+
// Create argSquared (x * x) and fix its type if invalid
579+
RexNode argSquared =
530580
rexBuilder.makeCall(
531581
SqlStdOperatorTable.MULTIPLY, argRef, argRef);
582+
583+
// Fix DECIMAL type if Calcite 1.38 created invalid type (scale > precision)
584+
RelDataType argSquaredType = fixDecimalType(typeFactory, argSquared.getType());
585+
if (!argSquaredType.equals(argSquared.getType())) {
586+
// Recreate the call with the fixed type
587+
argSquared = rexBuilder.makeCall(argSquaredType, SqlStdOperatorTable.MULTIPLY,
588+
java.util.Arrays.asList(argRef, argRef));
589+
}
590+
532591
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
533592

534593
RelDataType sumType =
535594
TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(),
536595
ImmutableList.of())
537596
.inferReturnType(oldCall.createBinding(oldAggRel));
538597
sumType = typeFactory.createTypeWithNullability(sumType, true);
598+
599+
// Fix sumType if Calcite 1.38 created invalid DECIMAL type (scale > precision)
600+
sumType = fixDecimalType(typeFactory, sumType);
539601
final AggregateCall sumArgSquaredAggCall =
540602
AggregateCall.create(
541603
new DrillCalciteSqlAggFunctionWrapper(
@@ -580,10 +642,19 @@ private RexNode reduceStddev(
580642
aggCallMapping,
581643
ImmutableList.of(argType));
582644

583-
final RexNode sumSquaredArg =
645+
// Create sumSquaredArg (SUM(x) * SUM(x)) and fix its type if invalid
646+
RexNode sumSquaredArg =
584647
rexBuilder.makeCall(
585648
SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
586649

650+
// Fix DECIMAL type if Calcite 1.38 created invalid type (scale > precision)
651+
RelDataType sumSquaredArgType = fixDecimalType(typeFactory, sumSquaredArg.getType());
652+
if (!sumSquaredArgType.equals(sumSquaredArg.getType())) {
653+
// Recreate the call with the fixed type
654+
sumSquaredArg = rexBuilder.makeCall(sumSquaredArgType, SqlStdOperatorTable.MULTIPLY,
655+
java.util.Arrays.asList(sumArg, sumArg));
656+
}
657+
587658
final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
588659
final RelDataType countType = countAgg.getReturnType(typeFactory);
589660
final AggregateCall countArgAggCall = getAggCall(oldCall, countAgg, countType);
@@ -682,6 +753,44 @@ private static <T> int lookupOrAdd(List<T> list, T element) {
682753
return ordinal;
683754
}
684755

756+
/**
757+
* Fix invalid DECIMAL types where scale > precision.
758+
* This can happen with Calcite 1.38 CALCITE-6427 where variance functions
759+
* use DECIMAL(2*p, 2*s) for intermediate calculations.
760+
*
761+
* @param typeFactory Type factory to create corrected types
762+
* @param type Type to check and potentially fix
763+
* @return Fixed type if invalid, original type otherwise
764+
*/
765+
private static RelDataType fixDecimalType(RelDataTypeFactory typeFactory, RelDataType type) {
766+
if (type.getSqlTypeName() != SqlTypeName.DECIMAL) {
767+
return type;
768+
}
769+
770+
int precision = type.getPrecision();
771+
int scale = type.getScale();
772+
773+
// Check if type is invalid (scale > precision)
774+
if (scale <= precision && precision <= 38) {
775+
return type; // Type is valid
776+
}
777+
778+
// Fix the type
779+
int maxPrecision = 38; // Drill's maximum DECIMAL precision
780+
781+
// First, cap precision at Drill's max
782+
if (precision > maxPrecision) {
783+
precision = maxPrecision;
784+
}
785+
786+
// Then ensure scale doesn't exceed precision
787+
if (scale > precision) {
788+
scale = precision;
789+
}
790+
791+
return typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale);
792+
}
793+
685794
/**
686795
* Do a shallow clone of oldAggRel and update aggCalls. Could be refactored
687796
* into Aggregate and subclasses - but it's only needed for some

exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/TypeInferenceUtils.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -895,13 +895,20 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
895895
isNullable
896896
);
897897
case VARDECIMAL:
898+
// For Calcite 1.38+ compatibility: Variance/stddev functions use double precision/scale
899+
// internally (CALCITE-6427), which can exceed Drill's DECIMAL(38,38) limit.
900+
// We need to ensure scale doesn't exceed precision.
901+
int maxPrecision = DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericPrecision();
902+
int maxScale = DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericScale();
903+
int desiredScale = Math.max(6, operandType.getScale());
904+
905+
// Ensure scale doesn't exceed maxPrecision (invalid DECIMAL type)
906+
int finalScale = Math.min(desiredScale, Math.min(maxScale, maxPrecision));
907+
898908
RelDataType sqlType = factory.createSqlType(
899909
SqlTypeName.DECIMAL,
900-
DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericPrecision(),
901-
Math.min(
902-
Math.max(6, operandType.getScale()),
903-
DrillRelDataTypeSystem.DRILL_REL_DATATYPE_SYSTEM.getMaxNumericScale()
904-
)
910+
maxPrecision,
911+
finalScale
905912
);
906913
return factory.createTypeWithNullability(sqlType, isNullable);
907914
default:

exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/conversion/DrillRexBuilder.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818
package org.apache.drill.exec.planner.sql.conversion;
1919

2020
import java.math.BigDecimal;
21+
import java.util.List;
2122

2223
import org.apache.calcite.rel.type.RelDataType;
2324
import org.apache.calcite.rel.type.RelDataTypeFactory;
2425
import org.apache.calcite.rex.RexBuilder;
2526
import org.apache.calcite.rex.RexLiteral;
2627
import org.apache.calcite.rex.RexNode;
28+
import org.apache.calcite.sql.SqlOperator;
29+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
2730
import org.apache.calcite.sql.type.SqlTypeName;
2831
import org.apache.drill.common.exceptions.UserException;
2932
import org.apache.drill.exec.util.DecimalUtility;
@@ -38,6 +41,97 @@ class DrillRexBuilder extends RexBuilder {
3841
super(typeFactory);
3942
}
4043

44+
/**
45+
* Override makeCall to fix DECIMAL precision/scale issues in Calcite 1.38.
46+
* CALCITE-6427 can create invalid DECIMAL types where scale > precision.
47+
* This version intercepts calls WITH explicit return type.
48+
*/
49+
@Override
50+
public RexNode makeCall(RelDataType returnType, SqlOperator op, List<RexNode> exprs) {
51+
// Fix DECIMAL return types for arithmetic operations
52+
if (returnType.getSqlTypeName() == SqlTypeName.DECIMAL) {
53+
int precision = returnType.getPrecision();
54+
int scale = returnType.getScale();
55+
56+
// If scale exceeds precision, fix it
57+
if (scale > precision) {
58+
System.out.println("DrillRexBuilder.makeCall(with type): fixing invalid DECIMAL type for " + op.getName() +
59+
": precision=" + precision + ", scale=" + scale);
60+
61+
// Cap precision at Drill's max (38)
62+
int maxPrecision = 38;
63+
if (precision > maxPrecision) {
64+
precision = maxPrecision;
65+
}
66+
67+
// Ensure scale doesn't exceed precision
68+
if (scale > precision) {
69+
scale = precision;
70+
}
71+
72+
System.out.println("DrillRexBuilder.makeCall(with type): corrected to precision=" + precision + ", scale=" + scale);
73+
74+
// Create corrected type
75+
returnType = typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale);
76+
}
77+
}
78+
79+
return super.makeCall(returnType, op, exprs);
80+
}
81+
82+
/**
83+
* Override makeCall to fix DECIMAL precision/scale issues in Calcite 1.38.
84+
* CALCITE-6427 can create invalid DECIMAL types where scale > precision.
85+
* This version intercepts calls WITHOUT explicit return type (type is inferred).
86+
* NOTE: Cannot override makeCall(SqlOperator, RexNode...) because it's final in RexBuilder.
87+
* Instead, override the List version which the varargs version calls internally.
88+
*/
89+
@Override
90+
public RexNode makeCall(SqlOperator op, List<? extends RexNode> exprs) {
91+
System.out.println("DrillRexBuilder.makeCall(no type): op=" + op.getName() + ", exprs=" + exprs.size());
92+
93+
// Call super to get the result with inferred type
94+
RexNode result = super.makeCall(op, exprs);
95+
96+
// Check if the inferred type has invalid DECIMAL precision/scale
97+
if (result.getType().getSqlTypeName() == SqlTypeName.DECIMAL) {
98+
int precision = result.getType().getPrecision();
99+
int scale = result.getType().getScale();
100+
101+
System.out.println("DrillRexBuilder.makeCall(no type): inferred DECIMAL type: precision=" + precision + ", scale=" + scale);
102+
103+
// If scale exceeds precision, recreate the call with fixed type
104+
if (scale > precision) {
105+
System.out.println("DrillRexBuilder.makeCall(no type): fixing invalid DECIMAL type for " + op.getName() +
106+
": precision=" + precision + ", scale=" + scale);
107+
108+
// Cap precision at Drill's max (38)
109+
int maxPrecision = 38;
110+
if (precision > maxPrecision) {
111+
precision = maxPrecision;
112+
}
113+
114+
// Ensure scale doesn't exceed precision
115+
if (scale > precision) {
116+
scale = precision;
117+
}
118+
119+
System.out.println("DrillRexBuilder.makeCall(no type): corrected to precision=" + precision + ", scale=" + scale);
120+
121+
// Create corrected type and recreate the call with fixed type
122+
RelDataType fixedType = typeFactory.createSqlType(SqlTypeName.DECIMAL, precision, scale);
123+
// Convert to List<RexNode> to call the 3-arg version with explicit type
124+
List<RexNode> exprList = new java.util.ArrayList<>();
125+
for (RexNode expr : exprs) {
126+
exprList.add(expr);
127+
}
128+
result = super.makeCall(fixedType, op, exprList);
129+
}
130+
}
131+
132+
return result;
133+
}
134+
41135
/**
42136
* Since Drill has different mechanism and rules for implicit casting,
43137
* ensureType() is overridden to avoid conflicting cast functions being added to the expressions.

0 commit comments

Comments
 (0)