Skip to content

Commit 3f422a1

Browse files
feat: Support FILTER clause in aggregate window functions (#17378)
* feat: Support `FILTER` clause in aggregate window functions * fix: Box `WindowFunction` in `ExprFuncKind` enum to reduce enum total size As suggested by `clippy`: ``` warning: large size difference between variants --> datafusion/expr/src/expr_fn.rs:772:1 | 772 | / pub enum ExprFuncKind { 773 | | Aggregate(AggregateFunction), | | ---------------------------- the second-largest variant contains at least 72 bytes 774 | | Window(WindowFunction), | | ---------------------- the largest variant contains at least 288 bytes 775 | | } | |_^ the entire enum is at least 288 bytes | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant = note: `#[warn(clippy::large_enum_variant)]` on by default help: consider boxing the large fields to reduce the total size of the enum | 774 - Window(WindowFunction), 774 + Window(Box<WindowFunction>), | ``` * test: Add DataFrame API test for FILTER clause on aggregate window functions * docs: Update aggregate and window function documentation with FILTER support * docs: Link missing proto fields to github issue in TODO comment
1 parent f70ded5 commit 3f422a1

File tree

34 files changed

+619
-116
lines changed

34 files changed

+619
-116
lines changed

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
200200
window_frame: window_function.params.window_frame,
201201
null_treatment: window_function.params.null_treatment,
202202
distinct: window_function.params.distinct,
203+
filter: window_function.params.filter,
203204
},
204205
}))
205206
};

datafusion/common/src/tree_node.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,48 @@ impl<
990990
}
991991
}
992992

993+
impl<
994+
'a,
995+
T: 'a,
996+
C0: TreeNodeContainer<'a, T>,
997+
C1: TreeNodeContainer<'a, T>,
998+
C2: TreeNodeContainer<'a, T>,
999+
C3: TreeNodeContainer<'a, T>,
1000+
> TreeNodeContainer<'a, T> for (C0, C1, C2, C3)
1001+
{
1002+
fn apply_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1003+
&'a self,
1004+
mut f: F,
1005+
) -> Result<TreeNodeRecursion> {
1006+
self.0
1007+
.apply_elements(&mut f)?
1008+
.visit_sibling(|| self.1.apply_elements(&mut f))?
1009+
.visit_sibling(|| self.2.apply_elements(&mut f))?
1010+
.visit_sibling(|| self.3.apply_elements(&mut f))
1011+
}
1012+
1013+
fn map_elements<F: FnMut(T) -> Result<Transformed<T>>>(
1014+
self,
1015+
mut f: F,
1016+
) -> Result<Transformed<Self>> {
1017+
self.0
1018+
.map_elements(&mut f)?
1019+
.map_data(|new_c0| Ok((new_c0, self.1, self.2, self.3)))?
1020+
.transform_sibling(|(new_c0, c1, c2, c3)| {
1021+
c1.map_elements(&mut f)?
1022+
.map_data(|new_c1| Ok((new_c0, new_c1, c2, c3)))
1023+
})?
1024+
.transform_sibling(|(new_c0, new_c1, c2, c3)| {
1025+
c2.map_elements(&mut f)?
1026+
.map_data(|new_c2| Ok((new_c0, new_c1, new_c2, c3)))
1027+
})?
1028+
.transform_sibling(|(new_c0, new_c1, new_c2, c3)| {
1029+
c3.map_elements(&mut f)?
1030+
.map_data(|new_c3| Ok((new_c0, new_c1, new_c2, new_c3)))
1031+
})
1032+
}
1033+
}
1034+
9931035
/// [`TreeNodeRefContainer`] contains references to elements that a function can be
9941036
/// applied on. The elements of the container are siblings so the continuation rules are
9951037
/// similar to [`TreeNodeRecursion::visit_sibling`].
@@ -1065,6 +1107,27 @@ impl<
10651107
}
10661108
}
10671109

1110+
impl<
1111+
'a,
1112+
T: 'a,
1113+
C0: TreeNodeContainer<'a, T>,
1114+
C1: TreeNodeContainer<'a, T>,
1115+
C2: TreeNodeContainer<'a, T>,
1116+
C3: TreeNodeContainer<'a, T>,
1117+
> TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2, &'a C3)
1118+
{
1119+
fn apply_ref_elements<F: FnMut(&'a T) -> Result<TreeNodeRecursion>>(
1120+
&self,
1121+
mut f: F,
1122+
) -> Result<TreeNodeRecursion> {
1123+
self.0
1124+
.apply_elements(&mut f)?
1125+
.visit_sibling(|| self.1.apply_elements(&mut f))?
1126+
.visit_sibling(|| self.2.apply_elements(&mut f))?
1127+
.visit_sibling(|| self.3.apply_elements(&mut f))
1128+
}
1129+
}
1130+
10681131
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
10691132
pub trait TreeNodeIterator: Iterator {
10701133
/// Apples `f` to each item in this iterator

datafusion/core/src/physical_planner.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,7 @@ pub fn create_window_expr_with_name(
16501650
window_frame,
16511651
null_treatment,
16521652
distinct,
1653+
filter,
16531654
},
16541655
} = window_fun.as_ref();
16551656
let physical_args =
@@ -1669,6 +1670,11 @@ pub fn create_window_expr_with_name(
16691670
let window_frame = Arc::new(window_frame.clone());
16701671
let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls)
16711672
== NullTreatment::IgnoreNulls;
1673+
let physical_filter = filter
1674+
.as_ref()
1675+
.map(|f| create_physical_expr(f, logical_schema, execution_props))
1676+
.transpose()?;
1677+
16721678
windows::create_window_expr(
16731679
fun,
16741680
name,
@@ -1679,6 +1685,7 @@ pub fn create_window_expr_with_name(
16791685
physical_schema,
16801686
ignore_nulls,
16811687
*distinct,
1688+
physical_filter,
16821689
)
16831690
}
16841691
other => plan_err!("Invalid window expression '{other:?}'"),

datafusion/core/tests/dataframe/mod.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,83 @@ async fn window_using_aggregates() -> Result<()> {
959959
Ok(())
960960
}
961961

962+
#[tokio::test]
963+
async fn window_aggregates_with_filter() -> Result<()> {
964+
// Define a small in-memory table to make expected values clear
965+
let ts: Int32Array = [1, 2, 3, 4, 5].into_iter().collect();
966+
let val: Int32Array = [-3, -2, 1, 4, -1].into_iter().collect();
967+
let batch = RecordBatch::try_from_iter(vec![
968+
("ts", Arc::new(ts) as _),
969+
("val", Arc::new(val) as _),
970+
])?;
971+
972+
let ctx = SessionContext::new();
973+
ctx.register_batch("t", batch)?;
974+
975+
let df = ctx.table("t").await?;
976+
977+
// Build filtered window aggregates over ORDER BY ts ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
978+
let mut exprs = vec![
979+
(datafusion_functions_aggregate::sum::sum_udaf(), "sum_pos"),
980+
(
981+
datafusion_functions_aggregate::average::avg_udaf(),
982+
"avg_pos",
983+
),
984+
(
985+
datafusion_functions_aggregate::min_max::min_udaf(),
986+
"min_pos",
987+
),
988+
(
989+
datafusion_functions_aggregate::min_max::max_udaf(),
990+
"max_pos",
991+
),
992+
(
993+
datafusion_functions_aggregate::count::count_udaf(),
994+
"cnt_pos",
995+
),
996+
]
997+
.into_iter()
998+
.map(|(func, alias)| {
999+
let w = WindowFunction::new(
1000+
WindowFunctionDefinition::AggregateUDF(func),
1001+
vec![col("val")],
1002+
);
1003+
1004+
Expr::from(w)
1005+
.order_by(vec![col("ts").sort(true, true)])
1006+
.window_frame(WindowFrame::new_bounds(
1007+
WindowFrameUnits::Rows,
1008+
WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
1009+
WindowFrameBound::CurrentRow,
1010+
))
1011+
.filter(col("val").gt(lit(0)))
1012+
.build()
1013+
.unwrap()
1014+
.alias(alias)
1015+
})
1016+
.collect::<Vec<_>>();
1017+
exprs.extend_from_slice(&[col("ts"), col("val")]);
1018+
1019+
let results = df.select(exprs)?.collect().await?;
1020+
1021+
assert_snapshot!(
1022+
batches_to_string(&results),
1023+
@r###"
1024+
+---------+---------+---------+---------+---------+----+-----+
1025+
| sum_pos | avg_pos | min_pos | max_pos | cnt_pos | ts | val |
1026+
+---------+---------+---------+---------+---------+----+-----+
1027+
| | | | | 0 | 1 | -3 |
1028+
| | | | | 0 | 2 | -2 |
1029+
| 1 | 1.0 | 1 | 1 | 1 | 3 | 1 |
1030+
| 5 | 2.5 | 1 | 4 | 2 | 4 | 4 |
1031+
| 5 | 2.5 | 1 | 4 | 2 | 5 | -1 |
1032+
+---------+---------+---------+---------+---------+----+-----+
1033+
"###
1034+
);
1035+
1036+
Ok(())
1037+
}
1038+
9621039
// Test issue: https://github.com/apache/datafusion/issues/10346
9631040
#[tokio::test]
9641041
async fn test_select_over_aggregate_schema() -> Result<()> {

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
289289
&extended_schema,
290290
false,
291291
false,
292+
None,
292293
)?;
293294
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
294295
vec![window_expr],
@@ -662,6 +663,7 @@ async fn run_window_test(
662663
&extended_schema,
663664
false,
664665
false,
666+
None,
665667
)?],
666668
exec1,
667669
false,
@@ -681,6 +683,7 @@ async fn run_window_test(
681683
&extended_schema,
682684
false,
683685
false,
686+
None,
684687
)?],
685688
exec2,
686689
search_mode.clone(),

datafusion/core/tests/physical_optimizer/enforce_sorting.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3686,6 +3686,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> {
36863686
input_schema.as_ref(),
36873687
false,
36883688
false,
3689+
None,
36893690
)?;
36903691
let window_exec = if window_expr.uses_bounded_memory() {
36913692
Arc::new(BoundedWindowAggExec::try_new(

datafusion/core/tests/physical_optimizer/test_utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ pub fn bounded_window_exec_with_partition(
266266
schema.as_ref(),
267267
false,
268268
false,
269+
None,
269270
)
270271
.unwrap();
271272

datafusion/core/tests/physical_optimizer/window_optimize.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod test {
4949
&partition,
5050
&[],
5151
Arc::new(frame),
52+
None,
5253
);
5354

5455
let bounded_agg_exec = BoundedWindowAggExec::try_new(

datafusion/expr/src/expr.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,8 @@ pub struct WindowFunctionParams {
12301230
pub order_by: Vec<Sort>,
12311231
/// Window frame
12321232
pub window_frame: WindowFrame,
1233+
/// Optional filter expression (FILTER (WHERE ...))
1234+
pub filter: Option<Box<Expr>>,
12331235
/// Specifies how NULL value is treated: ignore or respect
12341236
pub null_treatment: Option<NullTreatment>,
12351237
/// Distinct flag
@@ -1247,6 +1249,7 @@ impl WindowFunction {
12471249
partition_by: Vec::default(),
12481250
order_by: Vec::default(),
12491251
window_frame: WindowFrame::new(None),
1252+
filter: None,
12501253
null_treatment: None,
12511254
distinct: false,
12521255
},
@@ -2388,6 +2391,7 @@ impl NormalizeEq for Expr {
23882391
window_frame: self_window_frame,
23892392
partition_by: self_partition_by,
23902393
order_by: self_order_by,
2394+
filter: self_filter,
23912395
null_treatment: self_null_treatment,
23922396
distinct: self_distinct,
23932397
},
@@ -2400,13 +2404,19 @@ impl NormalizeEq for Expr {
24002404
window_frame: other_window_frame,
24012405
partition_by: other_partition_by,
24022406
order_by: other_order_by,
2407+
filter: other_filter,
24032408
null_treatment: other_null_treatment,
24042409
distinct: other_distinct,
24052410
},
24062411
} = other.as_ref();
24072412

24082413
self_fun.name() == other_fun.name()
24092414
&& self_window_frame == other_window_frame
2415+
&& match (self_filter, other_filter) {
2416+
(Some(a), Some(b)) => a.normalize_eq(b),
2417+
(None, None) => true,
2418+
_ => false,
2419+
}
24102420
&& self_null_treatment == other_null_treatment
24112421
&& self_args.len() == other_args.len()
24122422
&& self_args
@@ -2658,12 +2668,14 @@ impl HashNode for Expr {
26582668
partition_by: _,
26592669
order_by: _,
26602670
window_frame,
2671+
filter,
26612672
null_treatment,
26622673
distinct,
26632674
},
26642675
} = window_fun.as_ref();
26652676
fun.hash(state);
26662677
window_frame.hash(state);
2678+
filter.hash(state);
26672679
null_treatment.hash(state);
26682680
distinct.hash(state);
26692681
}
@@ -2967,6 +2979,7 @@ impl Display for SchemaDisplay<'_> {
29672979
partition_by,
29682980
order_by,
29692981
window_frame,
2982+
filter,
29702983
null_treatment,
29712984
distinct,
29722985
} = params;
@@ -2993,6 +3006,10 @@ impl Display for SchemaDisplay<'_> {
29933006
write!(f, " {null_treatment}")?;
29943007
}
29953008

3009+
if let Some(filter) = filter {
3010+
write!(f, " FILTER (WHERE {filter})")?;
3011+
}
3012+
29963013
if !partition_by.is_empty() {
29973014
write!(
29983015
f,
@@ -3370,6 +3387,7 @@ impl Display for Expr {
33703387
partition_by,
33713388
order_by,
33723389
window_frame,
3390+
filter,
33733391
null_treatment,
33743392
distinct,
33753393
} = params;
@@ -3380,6 +3398,10 @@ impl Display for Expr {
33803398
write!(f, "{nt}")?;
33813399
}
33823400

3401+
if let Some(fe) = filter {
3402+
write!(f, " FILTER (WHERE {fe})")?;
3403+
}
3404+
33833405
if !partition_by.is_empty() {
33843406
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
33853407
}

0 commit comments

Comments
 (0)