@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
33
33
import org .apache .spark .sql .catalyst .util .toPrettySQL
34
34
import org .apache .spark .sql .execution .aggregate .TypedAggregateExpression
35
35
import org .apache .spark .sql .internal .SQLConf
36
- import org .apache .spark .sql .types .{NumericType , StructType }
36
+ import org .apache .spark .sql .types .{StructType , TypeCollection }
37
37
38
38
/**
39
39
* A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy ]],
@@ -88,20 +88,20 @@ class RelationalGroupedDataset protected[sql](
88
88
case expr : Expression => Alias (expr, toPrettySQL(expr))()
89
89
}
90
90
91
- private [this ] def aggregateNumericColumns ( colNames : String * )( f : Expression => AggregateFunction )
92
- : DataFrame = {
91
+ private [this ] def aggregateNumericOrIntervalColumns (
92
+ colNames : String * )( f : Expression => AggregateFunction ) : DataFrame = {
93
93
94
94
val columnExprs = if (colNames.isEmpty) {
95
- // No columns specified. Use all numeric columns.
96
- df.numericColumns
95
+ // No columns specified. Use all numeric calculation supported columns.
96
+ df.numericCalculationSupportedColumns
97
97
} else {
98
- // Make sure all specified columns are numeric.
98
+ // Make sure all specified columns are numeric calculation supported columns .
99
99
colNames.map { colName =>
100
100
val namedExpr = df.resolve(colName)
101
- if (! namedExpr.dataType. isInstanceOf [ NumericType ] ) {
101
+ if (! TypeCollection . NumericAndInterval .acceptsType( namedExpr.dataType) ) {
102
102
throw new AnalysisException (
103
- s """ " $colName" is not a numeric column. """ +
104
- " Aggregation function can only be applied on a numeric column." )
103
+ s """ " $colName" is not a numeric or calendar interval column. """ +
104
+ " Aggregation function can only be applied on a numeric or calendar interval column." )
105
105
}
106
106
namedExpr
107
107
}
@@ -269,63 +269,64 @@ class RelationalGroupedDataset protected[sql](
269
269
def count (): DataFrame = toDF(Seq (Alias (Count (Literal (1 )).toAggregateExpression(), " count" )()))
270
270
271
271
/**
272
- * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
272
+ * Compute the average value for each numeric or calendar interval column for each group. This
273
+ * is an alias for `avg`.
273
274
* The resulting `DataFrame` will also contain the grouping columns.
274
275
* When specified columns are given, only compute the average values for them.
275
276
*
276
277
* @since 1.3.0
277
278
*/
278
279
@ scala.annotation.varargs
279
280
def mean (colNames : String * ): DataFrame = {
280
- aggregateNumericColumns (colNames : _* )(Average )
281
+ aggregateNumericOrIntervalColumns (colNames : _* )(Average )
281
282
}
282
283
283
284
/**
284
- * Compute the max value for each numeric columns for each group.
285
+ * Compute the max value for each numeric or calendar interval column for each group.
285
286
* The resulting `DataFrame` will also contain the grouping columns.
286
287
* When specified columns are given, only compute the max values for them.
287
288
*
288
289
* @since 1.3.0
289
290
*/
290
291
@ scala.annotation.varargs
291
292
def max (colNames : String * ): DataFrame = {
292
- aggregateNumericColumns (colNames : _* )(Max )
293
+ aggregateNumericOrIntervalColumns (colNames : _* )(Max )
293
294
}
294
295
295
296
/**
296
- * Compute the mean value for each numeric columns for each group.
297
+ * Compute the mean value for each numeric or calendar interval column for each group.
297
298
* The resulting `DataFrame` will also contain the grouping columns.
298
299
* When specified columns are given, only compute the mean values for them.
299
300
*
300
301
* @since 1.3.0
301
302
*/
302
303
@ scala.annotation.varargs
303
304
def avg (colNames : String * ): DataFrame = {
304
- aggregateNumericColumns (colNames : _* )(Average )
305
+ aggregateNumericOrIntervalColumns (colNames : _* )(Average )
305
306
}
306
307
307
308
/**
308
- * Compute the min value for each numeric column for each group.
309
+ * Compute the min value for each numeric or calendar interval column for each group.
309
310
* The resulting `DataFrame` will also contain the grouping columns.
310
311
* When specified columns are given, only compute the min values for them.
311
312
*
312
313
* @since 1.3.0
313
314
*/
314
315
@ scala.annotation.varargs
315
316
def min (colNames : String * ): DataFrame = {
316
- aggregateNumericColumns (colNames : _* )(Min )
317
+ aggregateNumericOrIntervalColumns (colNames : _* )(Min )
317
318
}
318
319
319
320
/**
320
- * Compute the sum for each numeric columns for each group.
321
+ * Compute the sum for each numeric or calendar interval column for each group.
321
322
* The resulting `DataFrame` will also contain the grouping columns.
322
323
* When specified columns are given, only compute the sum for them.
323
324
*
324
325
* @since 1.3.0
325
326
*/
326
327
@ scala.annotation.varargs
327
328
def sum (colNames : String * ): DataFrame = {
328
- aggregateNumericColumns (colNames : _* )(Sum )
329
+ aggregateNumericOrIntervalColumns (colNames : _* )(Sum )
329
330
}
330
331
331
332
/**
0 commit comments