Skip to content

Commit 66fee4a

Browse files
committed
feat: Add bool_and, bool_or aggregate functions
1 parent 59a2174 commit 66fee4a

File tree

9 files changed

+506
-3
lines changed

9 files changed

+506
-3
lines changed

datafusion/core/src/physical_plan/aggregates.rs

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ pub fn return_type(
8383
Ok(coerced_data_types[0].clone())
8484
}
8585
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
86+
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean),
8687
}
8788
}
8889

@@ -297,6 +298,13 @@ pub fn create_aggregate_expr(
297298
"MEDIAN(DISTINCT) aggregations are not available".to_string(),
298299
));
299300
}
301+
(AggregateFunction::BoolAnd, _) => Arc::new(expressions::BoolAnd::new(
302+
coerced_phy_exprs[0].clone(),
303+
name,
304+
)),
305+
(AggregateFunction::BoolOr, _) => {
306+
Arc::new(expressions::BoolOr::new(coerced_phy_exprs[0].clone(), name))
307+
}
300308
})
301309
}
302310

@@ -374,16 +382,19 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
374382
.collect(),
375383
Volatility::Immutable,
376384
),
385+
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
386+
Signature::exact(vec![DataType::Boolean], Volatility::Immutable)
387+
}
377388
}
378389
}
379390

380391
#[cfg(test)]
381392
mod tests {
382393
use super::*;
383394
use crate::physical_plan::expressions::{
384-
ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, Correlation,
385-
Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum,
386-
Variance,
395+
ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BoolAnd,
396+
BoolOr, Correlation, Count, Covariance, DistinctArrayAgg, DistinctCount, Max,
397+
Min, Stddev, Sum, Variance,
387398
};
388399
use crate::{error::Result, scalar::ScalarValue};
389400

@@ -995,6 +1006,45 @@ mod tests {
9951006
Ok(())
9961007
}
9971008

1009+
#[test]
1010+
fn test_bool_and_or_expr() -> Result<()> {
1011+
let funcs = vec![AggregateFunction::BoolAnd, AggregateFunction::BoolOr];
1012+
for fun in funcs {
1013+
let input_schema =
1014+
Schema::new(vec![Field::new("c1", DataType::Boolean, true)]);
1015+
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
1016+
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
1017+
)];
1018+
let result_agg_phy_exprs = create_aggregate_expr(
1019+
&fun,
1020+
false,
1021+
&input_phy_exprs[0..1],
1022+
&input_schema,
1023+
"c1",
1024+
)?;
1025+
match fun {
1026+
AggregateFunction::BoolAnd => {
1027+
assert!(result_agg_phy_exprs.as_any().is::<BoolAnd>());
1028+
assert_eq!("c1", result_agg_phy_exprs.name());
1029+
assert_eq!(
1030+
Field::new("c1", DataType::Boolean, true),
1031+
result_agg_phy_exprs.field().unwrap()
1032+
);
1033+
}
1034+
AggregateFunction::BoolOr => {
1035+
assert!(result_agg_phy_exprs.as_any().is::<BoolOr>());
1036+
assert_eq!("c1", result_agg_phy_exprs.name());
1037+
assert_eq!(
1038+
Field::new("c1", DataType::Boolean, true),
1039+
result_agg_phy_exprs.field().unwrap()
1040+
);
1041+
}
1042+
_ => {}
1043+
};
1044+
}
1045+
Ok(())
1046+
}
1047+
9981048
#[test]
9991049
fn test_median() -> Result<()> {
10001050
let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]);
@@ -1158,4 +1208,32 @@ mod tests {
11581208
let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]);
11591209
assert!(observed.is_err());
11601210
}
1211+
1212+
#[test]
1213+
fn test_bool_and_return_type() -> Result<()> {
1214+
let observed = return_type(&AggregateFunction::BoolAnd, &[DataType::Boolean])?;
1215+
assert_eq!(DataType::Boolean, observed);
1216+
1217+
Ok(())
1218+
}
1219+
1220+
#[test]
1221+
fn test_bool_and_no_utf8() {
1222+
let observed = return_type(&AggregateFunction::BoolAnd, &[DataType::Utf8]);
1223+
assert!(observed.is_err());
1224+
}
1225+
1226+
#[test]
1227+
fn test_bool_or_return_type() -> Result<()> {
1228+
let observed = return_type(&AggregateFunction::BoolOr, &[DataType::Boolean])?;
1229+
assert_eq!(DataType::Boolean, observed);
1230+
1231+
Ok(())
1232+
}
1233+
1234+
#[test]
1235+
fn test_bool_or_no_utf8() {
1236+
let observed = return_type(&AggregateFunction::BoolOr, &[DataType::Utf8]);
1237+
assert!(observed.is_err());
1238+
}
11611239
}

datafusion/expr/src/aggregate_function.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ pub enum AggregateFunction {
5757
ApproxPercentileContWithWeight,
5858
/// ApproxMedian
5959
ApproxMedian,
60+
/// BoolAnd
61+
BoolAnd,
62+
/// BoolOr
63+
BoolOr,
6064
}
6165

6266
impl fmt::Display for AggregateFunction {
@@ -92,6 +96,8 @@ impl FromStr for AggregateFunction {
9296
AggregateFunction::ApproxPercentileContWithWeight
9397
}
9498
"approx_median" => AggregateFunction::ApproxMedian,
99+
"bool_and" => AggregateFunction::BoolAnd,
100+
"bool_or" => AggregateFunction::BoolOr,
95101
_ => {
96102
return Err(DataFusionError::Plan(format!(
97103
"There is no built-in function named {}",

datafusion/expr/src/expr_fn.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,24 @@ pub fn approx_percentile_cont_with_weight(
179179
}
180180
}
181181

182+
/// Create an expression to represent the bool_and() aggregate function
183+
pub fn bool_and(expr: Expr) -> Expr {
184+
Expr::AggregateFunction {
185+
fun: aggregate_function::AggregateFunction::BoolAnd,
186+
distinct: false,
187+
args: vec![expr],
188+
}
189+
}
190+
191+
/// Create an expression to represent the bool_or() aggregate function
192+
pub fn bool_or(expr: Expr) -> Expr {
193+
Expr::AggregateFunction {
194+
fun: aggregate_function::AggregateFunction::BoolOr,
195+
distinct: false,
196+
args: vec![expr],
197+
}
198+
}
199+
182200
// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
183201
// varying arity functions
184202
/// Create an convenience function representing a unary scalar function

datafusion/physical-expr/src/coercion_rule/aggregate_rule.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,15 @@ pub fn coerce_types(
182182
}
183183
Ok(input_types.to_vec())
184184
}
185+
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
186+
if input_types[0] != DataType::Boolean {
187+
return Err(DataFusionError::Plan(format!(
188+
"The function {:?} does not support inputs of type {:?}.",
189+
agg_fun, input_types[0]
190+
)));
191+
}
192+
Ok(input_types.to_vec())
193+
}
185194
}
186195
}
187196

0 commit comments

Comments
 (0)