@@ -83,6 +83,7 @@ pub fn return_type(
83
83
Ok ( coerced_data_types[ 0 ] . clone ( ) )
84
84
}
85
85
AggregateFunction :: ApproxMedian => Ok ( coerced_data_types[ 0 ] . clone ( ) ) ,
86
+ AggregateFunction :: BoolAnd | AggregateFunction :: BoolOr => Ok ( DataType :: Boolean ) ,
86
87
}
87
88
}
88
89
@@ -297,6 +298,13 @@ pub fn create_aggregate_expr(
297
298
"MEDIAN(DISTINCT) aggregations are not available" . to_string ( ) ,
298
299
) ) ;
299
300
}
301
+ ( AggregateFunction :: BoolAnd , _) => Arc :: new ( expressions:: BoolAnd :: new (
302
+ coerced_phy_exprs[ 0 ] . clone ( ) ,
303
+ name,
304
+ ) ) ,
305
+ ( AggregateFunction :: BoolOr , _) => {
306
+ Arc :: new ( expressions:: BoolOr :: new ( coerced_phy_exprs[ 0 ] . clone ( ) , name) )
307
+ }
300
308
} )
301
309
}
302
310
@@ -374,16 +382,19 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
374
382
. collect ( ) ,
375
383
Volatility :: Immutable ,
376
384
) ,
385
+ AggregateFunction :: BoolAnd | AggregateFunction :: BoolOr => {
386
+ Signature :: exact ( vec ! [ DataType :: Boolean ] , Volatility :: Immutable )
387
+ }
377
388
}
378
389
}
379
390
380
391
#[ cfg( test) ]
381
392
mod tests {
382
393
use super :: * ;
383
394
use crate :: physical_plan:: expressions:: {
384
- ApproxDistinct , ApproxMedian , ApproxPercentileCont , ArrayAgg , Avg , Correlation ,
385
- Count , Covariance , DistinctArrayAgg , DistinctCount , Max , Min , Stddev , Sum ,
386
- Variance ,
395
+ ApproxDistinct , ApproxMedian , ApproxPercentileCont , ArrayAgg , Avg , BoolAnd ,
396
+ BoolOr , Correlation , Count , Covariance , DistinctArrayAgg , DistinctCount , Max ,
397
+ Min , Stddev , Sum , Variance ,
387
398
} ;
388
399
use crate :: { error:: Result , scalar:: ScalarValue } ;
389
400
@@ -995,6 +1006,45 @@ mod tests {
995
1006
Ok ( ( ) )
996
1007
}
997
1008
1009
+ #[ test]
1010
+ fn test_bool_and_or_expr ( ) -> Result < ( ) > {
1011
+ let funcs = vec ! [ AggregateFunction :: BoolAnd , AggregateFunction :: BoolOr ] ;
1012
+ for fun in funcs {
1013
+ let input_schema =
1014
+ Schema :: new ( vec ! [ Field :: new( "c1" , DataType :: Boolean , true ) ] ) ;
1015
+ let input_phy_exprs: Vec < Arc < dyn PhysicalExpr > > = vec ! [ Arc :: new(
1016
+ expressions:: Column :: new_with_schema( "c1" , & input_schema) . unwrap( ) ,
1017
+ ) ] ;
1018
+ let result_agg_phy_exprs = create_aggregate_expr (
1019
+ & fun,
1020
+ false ,
1021
+ & input_phy_exprs[ 0 ..1 ] ,
1022
+ & input_schema,
1023
+ "c1" ,
1024
+ ) ?;
1025
+ match fun {
1026
+ AggregateFunction :: BoolAnd => {
1027
+ assert ! ( result_agg_phy_exprs. as_any( ) . is:: <BoolAnd >( ) ) ;
1028
+ assert_eq ! ( "c1" , result_agg_phy_exprs. name( ) ) ;
1029
+ assert_eq ! (
1030
+ Field :: new( "c1" , DataType :: Boolean , true ) ,
1031
+ result_agg_phy_exprs. field( ) . unwrap( )
1032
+ ) ;
1033
+ }
1034
+ AggregateFunction :: BoolOr => {
1035
+ assert ! ( result_agg_phy_exprs. as_any( ) . is:: <BoolOr >( ) ) ;
1036
+ assert_eq ! ( "c1" , result_agg_phy_exprs. name( ) ) ;
1037
+ assert_eq ! (
1038
+ Field :: new( "c1" , DataType :: Boolean , true ) ,
1039
+ result_agg_phy_exprs. field( ) . unwrap( )
1040
+ ) ;
1041
+ }
1042
+ _ => { }
1043
+ } ;
1044
+ }
1045
+ Ok ( ( ) )
1046
+ }
1047
+
998
1048
#[ test]
999
1049
fn test_median ( ) -> Result < ( ) > {
1000
1050
let observed = return_type ( & AggregateFunction :: ApproxMedian , & [ DataType :: Utf8 ] ) ;
@@ -1158,4 +1208,32 @@ mod tests {
1158
1208
let observed = return_type ( & AggregateFunction :: Stddev , & [ DataType :: Utf8 ] ) ;
1159
1209
assert ! ( observed. is_err( ) ) ;
1160
1210
}
1211
+
1212
+ #[ test]
1213
+ fn test_bool_and_return_type ( ) -> Result < ( ) > {
1214
+ let observed = return_type ( & AggregateFunction :: BoolAnd , & [ DataType :: Boolean ] ) ?;
1215
+ assert_eq ! ( DataType :: Boolean , observed) ;
1216
+
1217
+ Ok ( ( ) )
1218
+ }
1219
+
1220
+ #[ test]
1221
+ fn test_bool_and_no_utf8 ( ) {
1222
+ let observed = return_type ( & AggregateFunction :: BoolAnd , & [ DataType :: Utf8 ] ) ;
1223
+ assert ! ( observed. is_err( ) ) ;
1224
+ }
1225
+
1226
+ #[ test]
1227
+ fn test_bool_or_return_type ( ) -> Result < ( ) > {
1228
+ let observed = return_type ( & AggregateFunction :: BoolOr , & [ DataType :: Boolean ] ) ?;
1229
+ assert_eq ! ( DataType :: Boolean , observed) ;
1230
+
1231
+ Ok ( ( ) )
1232
+ }
1233
+
1234
+ #[ test]
1235
+ fn test_bool_or_no_utf8 ( ) {
1236
+ let observed = return_type ( & AggregateFunction :: BoolOr , & [ DataType :: Utf8 ] ) ;
1237
+ assert ! ( observed. is_err( ) ) ;
1238
+ }
1161
1239
}
0 commit comments