@@ -124,7 +124,25 @@ public void onMatch(RelOptRuleCall ruleCall) {
124124 */
125125 private boolean containsAvgStddevVarCall (List <AggregateCall > aggCallList ) {
126126 for (AggregateCall call : aggCallList ) {
127+ // Check the aggregate function name directly
128+ String aggName = call .getAggregation ().getName ();
129+ if (aggName .equalsIgnoreCase ("AVG" ) ||
130+ aggName .equalsIgnoreCase ("STDDEV_POP" ) || aggName .equalsIgnoreCase ("STDDEV_SAMP" ) ||
131+ aggName .equalsIgnoreCase ("VAR_POP" ) || aggName .equalsIgnoreCase ("VAR_SAMP" ) ||
132+ aggName .equalsIgnoreCase ("SUM" ) || aggName .equalsIgnoreCase ("SUM0" ) ||
133+ aggName .equalsIgnoreCase ("$SUM0" )) {
134+ return true ;
135+ }
136+
137+ // Fallback: check by SqlKind and instanceof for standard Calcite functions
127138 SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility .extractSqlOperatorFromWrapper (call .getAggregation ());
139+ SqlKind kind = sqlAggFunction .getKind ();
140+ if (kind == SqlKind .AVG ||
141+ kind == SqlKind .STDDEV_POP || kind == SqlKind .STDDEV_SAMP ||
142+ kind == SqlKind .VAR_POP || kind == SqlKind .VAR_SAMP ||
143+ kind == SqlKind .SUM || kind == SqlKind .SUM0 ) {
144+ return true ;
145+ }
128146 if (sqlAggFunction instanceof SqlAvgAggFunction
129147 || sqlAggFunction instanceof SqlSumAggFunction ) {
130148 return true ;
@@ -229,16 +247,48 @@ private RexNode reduceAgg(
229247 Map <AggregateCall , RexNode > aggCallMapping ,
230248 List <RexNode > inputExprs ) {
231249 final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility .extractSqlOperatorFromWrapper (oldCall .getAggregation ());
232- if (sqlAggFunction instanceof SqlSumAggFunction ) {
250+ final SqlKind sqlKind = sqlAggFunction .getKind ();
251+
252+ // Handle SUM
253+ if (sqlKind == SqlKind .SUM || sqlKind == SqlKind .SUM0 ||
254+ sqlAggFunction instanceof SqlSumAggFunction ) {
233255 // replace original SUM(x) with
234256 // case COUNT(x) when 0 then null else SUM0(x) end
235257 return reduceSum (oldAggRel , oldCall , newCalls , aggCallMapping );
236258 }
237- if (sqlAggFunction instanceof SqlAvgAggFunction ) {
238- // for DECIMAL data types does not produce rewriting of complex calls,
239- // since SUM returns value with 38 precision and further handling of the value
240- // causes the loss of the scale
241- if (oldCall .getType ().getSqlTypeName () == SqlTypeName .DECIMAL ) {
259+
260+ // Handle AVG, VAR_*, STDDEV_* - check by SqlKind or by name for Drill-wrapped functions
261+ String aggName = oldCall .getAggregation ().getName ();
262+ boolean isVarianceOrAvg = (sqlKind == SqlKind .AVG || sqlKind == SqlKind .STDDEV_POP || sqlKind == SqlKind .STDDEV_SAMP ||
263+ sqlKind == SqlKind .VAR_POP || sqlKind == SqlKind .VAR_SAMP ||
264+ sqlAggFunction instanceof SqlAvgAggFunction ||
265+ aggName .equalsIgnoreCase ("AVG" ) || aggName .equalsIgnoreCase ("VAR_POP" ) ||
266+ aggName .equalsIgnoreCase ("VAR_SAMP" ) || aggName .equalsIgnoreCase ("STDDEV_POP" ) ||
267+ aggName .equalsIgnoreCase ("STDDEV_SAMP" ));
268+ if (isVarianceOrAvg ) {
269+
270+ // Determine the subtype from name if SqlKind is OTHER_FUNCTION (Drill-wrapped)
271+ SqlKind subtype = sqlKind ;
272+ if (sqlKind == SqlKind .OTHER_FUNCTION || sqlKind == SqlKind .OTHER ) {
273+ // Use aggName already declared above
274+ if (aggName .equalsIgnoreCase ("AVG" )) {
275+ subtype = SqlKind .AVG ;
276+ } else if (aggName .equalsIgnoreCase ("VAR_POP" )) {
277+ subtype = SqlKind .VAR_POP ;
278+ } else if (aggName .equalsIgnoreCase ("VAR_SAMP" )) {
279+ subtype = SqlKind .VAR_SAMP ;
280+ } else if (aggName .equalsIgnoreCase ("STDDEV_POP" )) {
281+ subtype = SqlKind .STDDEV_POP ;
282+ } else if (aggName .equalsIgnoreCase ("STDDEV_SAMP" )) {
283+ subtype = SqlKind .STDDEV_SAMP ;
284+ }
285+ }
286+
287+ // For DECIMAL data types, only skip reduction for AVG (not for VAR_*/STDDEV_*)
288+ // AVG reduction causes loss of scale, but variance/stddev MUST be reduced
289+ // to avoid Calcite 1.38 CALCITE-6427 bug that creates invalid DECIMAL types
290+ if (oldCall .getType ().getSqlTypeName () == SqlTypeName .DECIMAL &&
291+ subtype == SqlKind .AVG ) {
242292 return oldAggRel .getCluster ().getRexBuilder ().addAggCall (
243293 oldCall ,
244294 oldAggRel .getGroupCount (),
@@ -248,7 +298,6 @@ private RexNode reduceAgg(
248298 oldAggRel .getInput (),
249299 oldCall .getArgList ().get (0 ))));
250300 }
251- final SqlKind subtype = sqlAggFunction .getKind ();
252301 switch (subtype ) {
253302 case AVG :
254303 // replace original AVG(x) with SUM(x) / COUNT(x)
@@ -526,16 +575,29 @@ private RexNode reduceStddev(
526575 RexNode argRef = rexBuilder .makeCall (CastHighOp , inputExprs .get (argOrdinal ));
527576 inputExprs .set (argOrdinal , argRef );
528577
529- final RexNode argSquared =
578+ // Create argSquared (x * x) and fix its type if invalid
579+ RexNode argSquared =
530580 rexBuilder .makeCall (
531581 SqlStdOperatorTable .MULTIPLY , argRef , argRef );
582+
583+ // Fix DECIMAL type if Calcite 1.38 created invalid type (scale > precision)
584+ RelDataType argSquaredType = fixDecimalType (typeFactory , argSquared .getType ());
585+ if (!argSquaredType .equals (argSquared .getType ())) {
586+ // Recreate the call with the fixed type
587+ argSquared = rexBuilder .makeCall (argSquaredType , SqlStdOperatorTable .MULTIPLY ,
588+ java .util .Arrays .asList (argRef , argRef ));
589+ }
590+
532591 final int argSquaredOrdinal = lookupOrAdd (inputExprs , argSquared );
533592
534593 RelDataType sumType =
535594 TypeInferenceUtils .getDrillSqlReturnTypeInference (SqlKind .SUM .name (),
536595 ImmutableList .of ())
537596 .inferReturnType (oldCall .createBinding (oldAggRel ));
538597 sumType = typeFactory .createTypeWithNullability (sumType , true );
598+
599+ // Fix sumType if Calcite 1.38 created invalid DECIMAL type (scale > precision)
600+ sumType = fixDecimalType (typeFactory , sumType );
539601 final AggregateCall sumArgSquaredAggCall =
540602 AggregateCall .create (
541603 new DrillCalciteSqlAggFunctionWrapper (
@@ -580,10 +642,19 @@ private RexNode reduceStddev(
580642 aggCallMapping ,
581643 ImmutableList .of (argType ));
582644
583- final RexNode sumSquaredArg =
645+ // Create sumSquaredArg (SUM(x) * SUM(x)) and fix its type if invalid
646+ RexNode sumSquaredArg =
584647 rexBuilder .makeCall (
585648 SqlStdOperatorTable .MULTIPLY , sumArg , sumArg );
586649
650+ // Fix DECIMAL type if Calcite 1.38 created invalid type (scale > precision)
651+ RelDataType sumSquaredArgType = fixDecimalType (typeFactory , sumSquaredArg .getType ());
652+ if (!sumSquaredArgType .equals (sumSquaredArg .getType ())) {
653+ // Recreate the call with the fixed type
654+ sumSquaredArg = rexBuilder .makeCall (sumSquaredArgType , SqlStdOperatorTable .MULTIPLY ,
655+ java .util .Arrays .asList (sumArg , sumArg ));
656+ }
657+
587658 final SqlCountAggFunction countAgg = (SqlCountAggFunction ) SqlStdOperatorTable .COUNT ;
588659 final RelDataType countType = countAgg .getReturnType (typeFactory );
589660 final AggregateCall countArgAggCall = getAggCall (oldCall , countAgg , countType );
@@ -682,6 +753,44 @@ private static <T> int lookupOrAdd(List<T> list, T element) {
682753 return ordinal ;
683754 }
684755
756+ /**
757+ * Fix invalid DECIMAL types where scale > precision.
758+ * This can happen with Calcite 1.38 CALCITE-6427 where variance functions
759+ * use DECIMAL(2*p, 2*s) for intermediate calculations.
760+ *
761+ * @param typeFactory Type factory to create corrected types
762+ * @param type Type to check and potentially fix
763+ * @return Fixed type if invalid, original type otherwise
764+ */
765+ private static RelDataType fixDecimalType (RelDataTypeFactory typeFactory , RelDataType type ) {
766+ if (type .getSqlTypeName () != SqlTypeName .DECIMAL ) {
767+ return type ;
768+ }
769+
770+ int precision = type .getPrecision ();
771+ int scale = type .getScale ();
772+
773+ // Check if type is invalid (scale > precision)
774+ if (scale <= precision && precision <= 38 ) {
775+ return type ; // Type is valid
776+ }
777+
778+ // Fix the type
779+ int maxPrecision = 38 ; // Drill's maximum DECIMAL precision
780+
781+ // First, cap precision at Drill's max
782+ if (precision > maxPrecision ) {
783+ precision = maxPrecision ;
784+ }
785+
786+ // Then ensure scale doesn't exceed precision
787+ if (scale > precision ) {
788+ scale = precision ;
789+ }
790+
791+ return typeFactory .createSqlType (SqlTypeName .DECIMAL , precision , scale );
792+ }
793+
685794 /**
686795 * Do a shallow clone of oldAggRel and update aggCalls. Could be refactored
687796 * into Aggregate and subclasses - but it's only needed for some
0 commit comments