
Commit 39291cf

yaooqinn authored and cloud-fan committed
[SPARK-30048][SQL] Enable aggregates with interval type values for RelationalGroupedDataset
### What changes were proposed in this pull request?

Now that min/max/sum/avg are supported for intervals, we should also enable them in RelationalGroupedDataset.

### Why are the changes needed?

API consistency improvement.

### Does this PR introduce any user-facing change?

Yes, Dataset now supports min/max/sum/avg (mean) on interval columns.

### How was this patch tested?

Added a unit test.

Closes apache#26681 from yaooqinn/SPARK-30048.

Authored-by: Kent Yao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent: d7b268a · commit: 39291cf
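To make the user-facing change concrete, here is a minimal sketch (assuming a Spark 3.0 session such as spark-shell with spark.implicits._ in scope; the data and column names are illustrative, mirroring the new test below):

  import org.apache.spark.sql.types.CalendarIntervalType
  import spark.implicits._

  // Illustrative data: a key column and a string column that casts to an interval.
  val df = Seq((1, "1 day"), (1, "3 day"), (2, "2 day")).toDF("k", "v")
    .select($"k", $"v".cast(CalendarIntervalType).as("v"))

  // Before this change these calls threw "not a numeric column"; now they aggregate intervals.
  df.groupBy($"k").sum("v")   // k=1 -> 4 days, k=2 -> 2 days
  df.groupBy($"k").avg("v")   // k=1 -> 2 days, k=2 -> 2 days
  df.groupBy($"k").max("v")   // k=1 -> 3 days, k=2 -> 2 days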

4 files changed: +45 -24 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 2 additions & 2 deletions

@@ -79,8 +79,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 private[sql] object TypeCollection {

   /**
-   * Types that include numeric types and interval type. They are only used in unary_minus,
-   * unary_positive, add and subtract operations.
+   * Types that include numeric types and interval type, which support numeric type calculations,
+   * i.e. unary_minus, unary_positive, sum, avg, min, max, add and subtract operations.
    */
   val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
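For illustration, this collection accepts exactly the numeric types plus CalendarIntervalType. A sketch of its behavior (note that acceptsType is private[sql], so code like this only compiles from within an org.apache.spark.sql package):

  import org.apache.spark.sql.types._

  // TypeCollection.acceptsType returns true if any member type accepts the argument.
  TypeCollection.NumericAndInterval.acceptsType(IntegerType)          // true
  TypeCollection.NumericAndInterval.acceptsType(DecimalType(10, 2))   // true
  TypeCollection.NumericAndInterval.acceptsType(CalendarIntervalType) // true
  TypeCollection.NumericAndInterval.acceptsType(StringType)           // false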

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 3 additions & 3 deletions

@@ -268,9 +268,9 @@ class Dataset[T] private[sql](
     }
   }

-  private[sql] def numericColumns: Seq[Expression] = {
-    schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
-      queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
+  private[sql] def numericCalculationSupportedColumns: Seq[Expression] = {
+    queryExecution.analyzed.output.filter { attr =>
+      TypeCollection.NumericAndInterval.acceptsType(attr.dataType)
     }
   }
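One consequence of this rename worth noting: grouped aggregates called with no column names fall back to this method, so interval columns are now aggregated alongside numeric ones instead of being silently skipped. A sketch (names illustrative, same session assumptions as above):

  import org.apache.spark.sql.types.CalendarIntervalType
  import spark.implicits._

  val df = Seq(("a", 1, "1 day"), ("a", 2, "2 day")).toDF("k", "n", "i")
    .select($"k", $"n", $"i".cast(CalendarIntervalType).as("i"))

  // sum() with no arguments now aggregates both the Int column "n"
  // and the interval column "i"; the String key "k" is still excluded.
  df.groupBy($"k").sum()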

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 20 additions & 19 deletions

@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.types.{StructType, TypeCollection}

 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -88,20 +88,20 @@ class RelationalGroupedDataset protected[sql](
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }

-  private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
-    : DataFrame = {
+  private[this] def aggregateNumericOrIntervalColumns(
+      colNames: String*)(f: Expression => AggregateFunction): DataFrame = {

     val columnExprs = if (colNames.isEmpty) {
-      // No columns specified. Use all numeric columns.
-      df.numericColumns
+      // No columns specified. Use all numeric calculation supported columns.
+      df.numericCalculationSupportedColumns
     } else {
-      // Make sure all specified columns are numeric.
+      // Make sure all specified columns are numeric calculation supported columns.
       colNames.map { colName =>
         val namedExpr = df.resolve(colName)
-        if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+        if (!TypeCollection.NumericAndInterval.acceptsType(namedExpr.dataType)) {
           throw new AnalysisException(
-            s""""$colName" is not a numeric column. """ +
-              "Aggregation function can only be applied on a numeric column.")
+            s""""$colName" is not a numeric or calendar interval column. """ +
+              "Aggregation function can only be applied on a numeric or calendar interval column.")
         }
         namedExpr
       }
@@ -269,63 +269,64 @@ class RelationalGroupedDataset protected[sql](
   def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))

   /**
-   * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
+   * Compute the average value for each numeric or calendar interval column for each group. This
+   * is an alias for `avg`.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the average values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def mean(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Average)
   }

   /**
-   * Compute the max value for each numeric columns for each group.
+   * Compute the max value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the max values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def max(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Max)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Max)
   }

   /**
-   * Compute the mean value for each numeric columns for each group.
+   * Compute the mean value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the mean values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def avg(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Average)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Average)
   }

   /**
-   * Compute the min value for each numeric column for each group.
+   * Compute the min value for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the min values for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def min(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Min)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Min)
   }

   /**
-   * Compute the sum for each numeric columns for each group.
+   * Compute the sum for each numeric or calendar interval column for each group.
    * The resulting `DataFrame` will also contain the grouping columns.
    * When specified columns are given, only compute the sum for them.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
-    aggregateNumericColumns(colNames : _*)(Sum)
+    aggregateNumericOrIntervalColumns(colNames : _*)(Sum)
   }

   /**
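Columns outside the NumericAndInterval set still fail analysis, now with the updated message. A sketch of the expected failure (names illustrative, same session assumptions as above):

  import spark.implicits._

  val df = Seq((1, "x")).toDF("k", "s")

  // "s" is a StringType column, so this is expected to throw an AnalysisException:
  //   "s" is not a numeric or calendar interval column. Aggregation function
  //   can only be applied on a numeric or calendar interval column.
  df.groupBy($"k").max("s")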

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 20 additions & 0 deletions

@@ -964,4 +964,24 @@ class DataFrameAggregateSuite extends QueryTest with SharedSparkSession {
       Row(3, new CalendarInterval(0, 3, 0)) :: Nil)
     assert(df3.queryExecution.executedPlan.find(_.isInstanceOf[HashAggregateExec]).isDefined)
   }
+
+  test("Dataset agg functions support calendar intervals") {
+    val df1 = Seq((1, "1 day"), (2, "2 day"), (3, "3 day"), (3, null)).toDF("a", "b")
+    val df2 = df1.select('a, 'b cast CalendarIntervalType).groupBy('a % 2)
+    checkAnswer(df2.sum("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 4, 0)) :: Nil)
+    checkAnswer(df2.avg("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.mean("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 2, 0)) :: Nil)
+    checkAnswer(df2.max("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 3, 0)) :: Nil)
+    checkAnswer(df2.min("b"),
+      Row(0, new CalendarInterval(0, 2, 0)) ::
+        Row(1, new CalendarInterval(0, 1, 0)) :: Nil)
+  }
 }
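Note the (3, null) row in the fixture: as with numeric aggregates, null intervals are ignored, which is why group 1 (a % 2 == 1, with values 1 day, 3 day, and a null) sums to 4 days but averages to 2 days over the two non-null rows.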
