Skip to content

Commit 844df7e

Browse files
authored
HIVE-29339: Remove DISTINCT indicator from SqlAggFunctions (#6212)
1. Drop CanAggregateDistinct and refactor dependent code accordingly. 2. Remove isDistinct indicator from all classes extending SqlAggFunction. 3. Move the part handling window functions from SqlFunctionConverter#buildAST to ASTConverter 4. Generalize the generation of TOK_FUNCTIONSTAR for aggregate functions by exploiting SqlOperator#getSqlSyntax 5. Replace CalciteUDAF with SqlBasicAggFunction.create since the former does not bring any additional info (operandTypeInference is removed but it is not used anyways from Hive).
1 parent 1271e84 commit 844df7e

13 files changed

+61
-117
lines changed

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/HiveCalciteUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ public static AggregateCall createSingleArgAggCall(String funcName, RelOptCluste
954954
PrimitiveTypeInfo typeInfo, Integer pos, RelDataType aggFnRetType) {
955955
ImmutableList.Builder<RelDataType> aggArgRelDTBldr = new ImmutableList.Builder<RelDataType>();
956956
aggArgRelDTBldr.add(TypeConverter.convert(typeInfo, cluster.getTypeFactory()));
957-
SqlAggFunction aggFunction = SqlFunctionConverter.getCalciteAggFn(funcName, false,
957+
SqlAggFunction aggFunction = SqlFunctionConverter.getCalciteAggFn(funcName,
958958
aggArgRelDTBldr.build(), aggFnRetType);
959959
List<Integer> argList = new ArrayList<Integer>();
960960
argList.add(pos);

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/CanAggregateDistinct.java

Lines changed: 0 additions & 27 deletions
This file was deleted.

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlAverageAggFunction.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
import org.apache.calcite.sql.SqlFunctionCategory;
2323
import org.apache.calcite.sql.SqlKind;
2424
import org.apache.calcite.sql.SqlSplittableAggFunction;
25+
import org.apache.calcite.sql.SqlSyntax;
2526
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
2627
import org.apache.calcite.sql.type.SqlOperandTypeInference;
2728
import org.apache.calcite.sql.type.SqlReturnTypeInference;
2829
import org.apache.calcite.util.Optionality;
2930

30-
public class HiveSqlAverageAggFunction extends SqlAggFunction implements CanAggregateDistinct {
31-
private final boolean isDistinct;
31+
public class HiveSqlAverageAggFunction extends SqlAggFunction {
3232

33-
public HiveSqlAverageAggFunction(boolean isDistinct, SqlReturnTypeInference returnTypeInference,
33+
public HiveSqlAverageAggFunction(SqlReturnTypeInference returnTypeInference,
3434
SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
3535
super(
3636
"avg",
@@ -43,7 +43,6 @@ public HiveSqlAverageAggFunction(boolean isDistinct, SqlReturnTypeInference retu
4343
false,
4444
false,
4545
Optionality.FORBIDDEN);
46-
this.isDistinct = isDistinct;
4746
}
4847

4948
@Override
@@ -55,7 +54,7 @@ public <T> T unwrap(Class<T> clazz) {
5554
}
5655

5756
@Override
58-
public boolean isDistinct() {
59-
return isDistinct;
57+
public SqlSyntax getSyntax() {
58+
return SqlSyntax.FUNCTION_STAR;
6059
}
6160
}

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlCountAggFunction.java

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@
4141

4242
import com.google.common.collect.ImmutableList;
4343

44-
public class HiveSqlCountAggFunction extends SqlAggFunction implements CanAggregateDistinct {
45-
46-
final boolean isDistinct;
44+
public class HiveSqlCountAggFunction extends SqlAggFunction {
4745
final SqlReturnTypeInference returnTypeInference;
4846
final SqlOperandTypeInference operandTypeInference;
4947
final SqlOperandTypeChecker operandTypeChecker;
5048

51-
public HiveSqlCountAggFunction(boolean isDistinct, SqlReturnTypeInference returnTypeInference,
49+
public HiveSqlCountAggFunction(SqlReturnTypeInference returnTypeInference,
5250
SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
5351
super(
5452
"count",
@@ -57,17 +55,11 @@ public HiveSqlCountAggFunction(boolean isDistinct, SqlReturnTypeInference return
5755
operandTypeInference,
5856
operandTypeChecker,
5957
SqlFunctionCategory.NUMERIC);
60-
this.isDistinct = isDistinct;
6158
this.returnTypeInference = returnTypeInference;
6259
this.operandTypeChecker = operandTypeChecker;
6360
this.operandTypeInference = operandTypeInference;
6461
}
6562

66-
@Override
67-
public boolean isDistinct() {
68-
return isDistinct;
69-
}
70-
7163
@Override
7264
public SqlSyntax getSyntax() {
7365
return SqlSyntax.FUNCTION_STAR;
@@ -91,7 +83,7 @@ class HiveCountSplitter extends CountSplitter {
9183
@Override
9284
public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
9385
return AggregateCall.create(
94-
new HiveSqlCountAggFunction(isDistinct, returnTypeInference, operandTypeInference, operandTypeChecker),
86+
new HiveSqlCountAggFunction(returnTypeInference, operandTypeInference, operandTypeChecker),
9587
false, ImmutableIntList.of(), -1,
9688
typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true), "count");
9789
}
@@ -122,14 +114,14 @@ public AggregateCall topSplit(RexBuilder rexBuilder,
122114
}
123115
int ordinal = extra.register(node);
124116
return AggregateCall.create(
125-
new HiveSqlSumEmptyIsZeroAggFunction(isDistinct, returnTypeInference, operandTypeInference, operandTypeChecker),
117+
new HiveSqlSumEmptyIsZeroAggFunction(returnTypeInference, operandTypeInference, operandTypeChecker),
126118
false, ImmutableList.of(ordinal), -1, aggregateCall.type, aggregateCall.name);
127119
}
128120
}
129121

130122
@Override
131123
public @Nullable SqlAggFunction getRollup() {
132-
return new HiveSqlSumEmptyIsZeroAggFunction(isDistinct(), getReturnTypeInference(), getOperandTypeInference(),
124+
return new HiveSqlSumEmptyIsZeroAggFunction(getReturnTypeInference(), getOperandTypeInference(),
133125
getOperandTypeChecker());
134126
}
135127
}

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumAggFunction.java

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.calcite.sql.SqlKind;
3131
import org.apache.calcite.sql.SqlSplittableAggFunction;
3232
import org.apache.calcite.sql.SqlSplittableAggFunction.SumSplitter;
33+
import org.apache.calcite.sql.SqlSyntax;
3334
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
3435
import org.apache.calcite.sql.type.ReturnTypes;
3536
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
@@ -46,15 +47,14 @@
4647
* <code>long</code>, <code>float</code>, <code>double</code>), and the result
4748
* is the same type.
4849
*/
49-
public class HiveSqlSumAggFunction extends SqlAggFunction implements CanAggregateDistinct{
50-
final boolean isDistinct;
50+
public class HiveSqlSumAggFunction extends SqlAggFunction {
5151
final SqlReturnTypeInference returnTypeInference;
5252
final SqlOperandTypeInference operandTypeInference;
5353
final SqlOperandTypeChecker operandTypeChecker;
5454

5555
//~ Constructors -----------------------------------------------------------
5656

57-
public HiveSqlSumAggFunction(boolean isDistinct, SqlReturnTypeInference returnTypeInference,
57+
public HiveSqlSumAggFunction(SqlReturnTypeInference returnTypeInference,
5858
SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
5959
super(
6060
"sum",
@@ -66,14 +66,9 @@ public HiveSqlSumAggFunction(boolean isDistinct, SqlReturnTypeInference returnTy
6666
this.returnTypeInference = returnTypeInference;
6767
this.operandTypeChecker = operandTypeChecker;
6868
this.operandTypeInference = operandTypeInference;
69-
this.isDistinct = isDistinct;
7069
}
7170

7271
//~ Methods ----------------------------------------------------------------
73-
@Override
74-
public boolean isDistinct() {
75-
return isDistinct;
76-
}
7772

7873
@Override
7974
public <T> T unwrap(Class<T> clazz) {
@@ -89,7 +84,7 @@ class HiveSumSplitter extends SumSplitter {
8984
public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
9085
RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
9186
return AggregateCall.create(
92-
new HiveSqlCountAggFunction(isDistinct, ReturnTypes.explicit(countRetType), operandTypeInference, operandTypeChecker),
87+
new HiveSqlCountAggFunction(ReturnTypes.explicit(countRetType), operandTypeInference, operandTypeChecker),
9388
false, ImmutableIntList.of(), -1, countRetType, "count");
9489
}
9590

@@ -120,11 +115,26 @@ public AggregateCall topSplit(RexBuilder rexBuilder,
120115
throw new AssertionError("unexpected count " + merges);
121116
}
122117
int ordinal = extra.register(node);
123-
return AggregateCall.create(new HiveSqlSumAggFunction(isDistinct, returnTypeInference, operandTypeInference, operandTypeChecker),
124-
false, ImmutableList.of(ordinal), -1, aggregateCall.type, aggregateCall.name);
118+
return AggregateCall.create(new HiveSqlSumAggFunction(returnTypeInference,
119+
operandTypeInference,
120+
operandTypeChecker),
121+
false,
122+
false,
123+
false,
124+
ImmutableList.of(ordinal),
125+
-1,
126+
aggregateCall.distinctKeys,
127+
aggregateCall.collation,
128+
aggregateCall.type,
129+
aggregateCall.name);
125130
}
126131
}
127132

133+
@Override
134+
public SqlSyntax getSyntax() {
135+
return SqlSyntax.FUNCTION_STAR;
136+
}
137+
128138
@Override
129139
public SqlAggFunction getRollup() {
130140
return this;

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/functions/HiveSqlSumEmptyIsZeroAggFunction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
public class HiveSqlSumEmptyIsZeroAggFunction extends SqlAggFunction {
4343
//~ Constructors -----------------------------------------------------------
4444

45-
public HiveSqlSumEmptyIsZeroAggFunction(boolean isDistinct, SqlReturnTypeInference returnTypeInference,
45+
public HiveSqlSumEmptyIsZeroAggFunction(SqlReturnTypeInference returnTypeInference,
4646
SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
4747
super("$SUM0",
4848
null,

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveAggregateReduceFunctionsRule.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,6 @@ private RexNode reduceSum0(
300300
final AggregateCall sumCall =
301301
AggregateCall.create(
302302
new HiveSqlSumAggFunction(
303-
oldCall.isDistinct(),
304303
ReturnTypes.explicit(sumReturnType),
305304
oldCall.getAggregation().getOperandTypeInference(),
306305
oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
@@ -346,7 +345,6 @@ private RexNode reduceAvg(
346345
final AggregateCall sumCall =
347346
AggregateCall.create(
348347
new HiveSqlSumAggFunction(
349-
oldCall.isDistinct(),
350348
ReturnTypes.explicit(sumReturnType),
351349
oldCall.getAggregation().getOperandTypeInference(),
352350
oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
@@ -364,7 +362,6 @@ private RexNode reduceAvg(
364362
final AggregateCall countCall =
365363
AggregateCall.create(
366364
new HiveSqlCountAggFunction(
367-
oldCall.isDistinct(),
368365
ReturnTypes.explicit(countRetType),
369366
oldCall.getAggregation().getOperandTypeInference(),
370367
oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.COUNT,
@@ -445,7 +442,6 @@ private RexNode reduceStddev(
445442
final AggregateCall sumArgSquaredAggCall =
446443
createAggregateCallWithBinding(typeFactory,
447444
new HiveSqlSumAggFunction(
448-
oldCall.isDistinct(),
449445
ReturnTypes.explicit(sumSquaredReturnType),
450446
InferTypes.explicit(Collections.singletonList(argSquared.getType())),
451447
oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
@@ -461,7 +457,6 @@ private RexNode reduceStddev(
461457
final AggregateCall sumArgAggCall =
462458
AggregateCall.create(
463459
new HiveSqlSumAggFunction(
464-
oldCall.isDistinct(),
465460
ReturnTypes.explicit(sumReturnType),
466461
InferTypes.explicit(Collections.singletonList(argRef.getType())),
467462
oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.SUM,
@@ -490,7 +485,6 @@ private RexNode reduceStddev(
490485
final AggregateCall countArgAggCall =
491486
AggregateCall.create(
492487
new HiveSqlCountAggFunction(
493-
oldCall.isDistinct(),
494488
ReturnTypes.explicit(countRetType),
495489
oldCall.getAggregation().getOperandTypeInference(),
496490
oldCall.getAggregation().getOperandTypeChecker()), //SqlStdOperatorTable.COUNT,

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveWindowingLastValueRewrite.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
import org.apache.calcite.rex.RexWindow;
3737
import org.apache.calcite.sql.SqlKind;
3838
import org.apache.calcite.sql.SqlAggFunction;
39+
import org.apache.calcite.sql.fun.SqlBasicAggFunction;
3940
import org.apache.commons.collections4.CollectionUtils;
40-
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
4141

4242
/**
4343
* Rule to rewrite a window function containing a last value clause.
@@ -103,9 +103,8 @@ public RexNode visitOver(RexOver over) {
103103
newOrderKeys.add(new RexFieldCollation(orderKey.left, flags));
104104
}
105105
SqlAggFunction s = (SqlAggFunction) over.op;
106-
SqlFunctionConverter.CalciteUDAF newSqlAggFunction = new SqlFunctionConverter.CalciteUDAF(
107-
over.isDistinct(), FIRST_VALUE_FUNC, s.getReturnTypeInference(), s.getOperandTypeInference(),
108-
s.getOperandTypeChecker());
106+
SqlAggFunction newSqlAggFunction = SqlBasicAggFunction.create(
107+
FIRST_VALUE_FUNC, SqlKind.OTHER_FUNCTION, s.getReturnTypeInference(), s.getOperandTypeChecker());
109108
List<RexNode> clonedOperands = visitList(over.operands, new boolean[] {false});
110109
RexWindow window = visitWindow(over.getWindow());
111110
return rexBuilder.makeOver(over.type, newSqlAggFunction, clonedOperands,

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/jdbc/JDBCAggregationPushDownRule.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ public boolean matches(RelOptRuleCall call) {
6262
SqlAggFunction f = relOptRuleOperand.getAggregation();
6363
if (f instanceof HiveSqlCountAggFunction) {
6464
//count distinct with more that one argument is not supported
65-
HiveSqlCountAggFunction countAgg = (HiveSqlCountAggFunction)f;
66-
if (countAgg.isDistinct() && 1 < relOptRuleOperand.getArgList().size()) {
65+
if (relOptRuleOperand.isDistinct() && relOptRuleOperand.getArgList().size() > 1) {
6766
return false;
6867
}
6968
}

ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/translator/ASTConverter.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.apache.calcite.rex.RexWindowBound;
6767
import org.apache.calcite.sql.SqlKind;
6868
import org.apache.calcite.sql.SqlOperator;
69+
import org.apache.calcite.sql.SqlSyntax;
6970
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
7071
import org.apache.calcite.sql.type.SqlTypeName;
7172
import org.apache.calcite.util.ImmutableBitSet;
@@ -1000,7 +1001,21 @@ public ASTNode visitOver(RexOver over) {
10001001
}
10011002

10021003
// 1. Translate the UDAF
1003-
final ASTNode wUDAFAst = visitCall(over);
1004+
// Case for the common window function calls: e.g., MIN(salary)
1005+
ASTNode wUDAFAst = ASTBuilder.createAST(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
1006+
// Case for window function calls with star syntax: e.g., COUNT(*), SUM(*), AVG(*)
1007+
if (over.getOperands().isEmpty() && over.op.getSyntax() == SqlSyntax.FUNCTION_STAR) {
1008+
wUDAFAst = ASTBuilder.createAST(HiveParser.TOK_FUNCTIONSTAR, "TOK_FUNCTIONSTAR");
1009+
}
1010+
// Case for window functions with DISTINCT: e.g., COUNT(DISTINCT deptno)
1011+
if (over.isDistinct()) {
1012+
wUDAFAst = ASTBuilder.createAST(HiveParser.TOK_FUNCTIONDI, "TOK_FUNCTIONDI");
1013+
}
1014+
wUDAFAst.addChild(ASTBuilder.createAST(HiveParser.Identifier, over.op.getName()));
1015+
wUDAFAst.setTypeInfo(TypeConverter.convert(over.type));
1016+
for (RexNode operand : over.getOperands()) {
1017+
wUDAFAst.addChild(operand.accept(this));
1018+
}
10041019

10051020
// 2. Add TOK_WINDOW as child of UDAF
10061021
ASTNode wSpec = ASTBuilder.createAST(HiveParser.TOK_WINDOWSPEC, "TOK_WINDOWSPEC");

0 commit comments

Comments
 (0)