Skip to content

Commit 0bf7ac6

Browse files
authored
feat: support aggregation function in window clause. (#10887)
1 parent c7329bf commit 0bf7ac6

File tree

15 files changed

+602
-230
lines changed

15 files changed

+602
-230
lines changed

src/query/expression/src/type_check.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,21 @@ pub fn check_function<Index: ColumnIndex>(
195195
);
196196
}
197197

198+
// Do not check grouping
199+
if name == "grouping" {
200+
debug_assert!(candidates.len() == 1);
201+
let (id, function) = candidates.into_iter().next().unwrap();
202+
let return_type = function.signature.return_type.clone();
203+
return Ok(Expr::FunctionCall {
204+
span,
205+
id,
206+
function,
207+
generics: vec![],
208+
args: args.to_vec(),
209+
return_type,
210+
});
211+
}
212+
198213
let auto_cast_rules = fn_registry.get_auto_cast_rules(name);
199214

200215
let mut fail_resaons = Vec::with_capacity(candidates.len());

src/query/functions/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,31 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
166166
Ok(())
167167
}
168168

169-
/// we already have accumulate_keys, so we don't need to implement this
170-
fn accumulate_row(&self, _place: StateAddr, _columns: &[Column], _row: usize) -> Result<()> {
171-
unreachable!()
169+
fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> {
170+
let col = &columns[0];
171+
let validity = column_merge_validity(col, None);
172+
let not_null_columns = vec![col.remove_nullable()];
173+
let not_null_columns = &not_null_columns;
174+
175+
match validity {
176+
Some(v) if v.unset_bits() > 0 => {
177+
// all nulls
178+
if v.unset_bits() == v.len() {
179+
return Ok(());
180+
}
181+
182+
if unsafe { v.get_bit_unchecked(row) } {
183+
self.set_flag(place, 1);
184+
self.nested.accumulate_row(place, not_null_columns, row)?;
185+
}
186+
}
187+
_ => {
188+
self.nested.accumulate_row(place, not_null_columns, row)?;
189+
self.set_flag(place, 1);
190+
}
191+
}
192+
193+
Ok(())
172194
}
173195

174196
fn serialize(&self, place: StateAddr, writer: &mut Vec<u8>) -> Result<()> {

src/query/service/src/pipelines/pipeline_builder.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -735,9 +735,8 @@ impl PipelineBuilder {
735735
})
736736
.collect::<Result<Vec<_>>>()?;
737737

738+
let old_output_len = self.main_pipeline.output_len();
738739
if !partition_by.is_empty() || !order_by.is_empty() {
739-
let old_output_len = self.main_pipeline.output_len();
740-
741740
let mut sort_desc = Vec::with_capacity(partition_by.len() + order_by.len());
742741

743742
for offset in &partition_by {
@@ -757,10 +756,9 @@ impl PipelineBuilder {
757756
}
758757

759758
self.build_sort_pipeline(input_schema.clone(), sort_desc, window.plan_id, None)?;
760-
761-
self.main_pipeline.resize(old_output_len)?;
762759
}
763-
760+
// `TransformWindow` is a pipeline breaker.
761+
self.main_pipeline.resize(1)?;
764762
let func = WindowFunctionInfo::try_create(&window.func, &input_schema)?;
765763
// Window
766764
self.main_pipeline.add_transform(|input, output| {
@@ -773,7 +771,9 @@ impl PipelineBuilder {
773771
window.window_frame.clone(),
774772
)?;
775773
Ok(ProcessorPtr::create(transform))
776-
})
774+
})?;
775+
776+
self.main_pipeline.resize(old_output_len)
777777
}
778778

779779
fn build_sort(&mut self, sort: &Sort) -> Result<()> {

src/query/sql/src/planner/binder/aggregate.rs

Lines changed: 59 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ use itertools::Itertools;
2929
use super::prune_by_children;
3030
use crate::binder::scalar::ScalarBinder;
3131
use crate::binder::select::SelectList;
32-
use crate::binder::window::WindowFunctionInfo;
3332
use crate::binder::Binder;
3433
use crate::binder::ColumnBinding;
3534
use crate::binder::Visibility;
36-
use crate::binder::WindowOrderByInfo;
3735
use crate::optimizer::SExpr;
3836
use crate::plans::Aggregate;
3937
use crate::plans::AggregateFunction;
@@ -86,16 +84,20 @@ pub struct AggregateInfo {
8684
pub grouping_sets: Vec<Vec<IndexType>>,
8785
}
8886

89-
pub(super) struct AggregateAndWindowRewriter<'a> {
87+
pub(super) struct AggregateRewriter<'a> {
9088
pub bind_context: &'a mut BindContext,
9189
pub metadata: MetadataRef,
90+
// If the aggregate function is in the arguments of window function,
91+
// ignore it here, it will be processed later when analyzing window.
92+
in_window: bool,
9293
}
9394

94-
impl<'a> AggregateAndWindowRewriter<'a> {
95+
impl<'a> AggregateRewriter<'a> {
9596
pub fn new(bind_context: &'a mut BindContext, metadata: MetadataRef) -> Self {
9697
Self {
9798
bind_context,
9899
metadata,
100+
in_window: false,
99101
}
100102
}
101103

@@ -154,8 +156,55 @@ impl<'a> AggregateAndWindowRewriter<'a> {
154156

155157
ScalarExpr::AggregateFunction(agg_func) => self.replace_aggregate_function(agg_func),
156158

157-
// already resolved in `analyze_window_select`
158-
ScalarExpr::WindowFunction(window) => self.replace_window_function(window),
159+
ScalarExpr::WindowFunction(window) => {
160+
self.in_window = true;
161+
162+
let partition_by = window
163+
.partition_by
164+
.iter()
165+
.map(|part| self.visit(part))
166+
.collect::<Result<Vec<_>>>()?;
167+
let order_by = window
168+
.order_by
169+
.iter()
170+
.map(|order| {
171+
Ok(WindowOrderBy {
172+
expr: self.visit(&order.expr)?,
173+
asc: order.asc,
174+
nulls_first: order.nulls_first,
175+
})
176+
})
177+
.collect::<Result<Vec<_>>>()?;
178+
let func = match &window.func {
179+
WindowFuncType::Aggregate(agg) => {
180+
let new_args = agg
181+
.args
182+
.iter()
183+
.map(|arg| self.visit(arg))
184+
.collect::<Result<Vec<_>>>()?;
185+
WindowFuncType::Aggregate(AggregateFunction {
186+
func_name: agg.func_name.clone(),
187+
args: new_args,
188+
display_name: agg.display_name.clone(),
189+
distinct: agg.distinct,
190+
params: agg.params.clone(),
191+
return_type: agg.return_type.clone(),
192+
})
193+
}
194+
func => func.clone(),
195+
};
196+
197+
self.in_window = false;
198+
199+
Ok(WindowFunc {
200+
display_name: window.display_name.clone(),
201+
func,
202+
partition_by,
203+
order_by,
204+
frame: window.frame.clone(),
205+
}
206+
.into())
207+
}
159208
}
160209
}
161210

@@ -272,205 +321,17 @@ impl<'a> AggregateAndWindowRewriter<'a> {
272321

273322
Ok(replaced_func.into())
274323
}
275-
276-
fn replace_window_function(&mut self, window: &WindowFunc) -> Result<ScalarExpr> {
277-
let window_infos = &mut self.bind_context.windows;
278-
279-
let mut replaced_partition_items: Vec<ScalarExpr> =
280-
Vec::with_capacity(window.partition_by.len());
281-
let mut replaced_order_by_items: Vec<WindowOrderBy> =
282-
Vec::with_capacity(window.order_by.len());
283-
let mut agg_args = vec![];
284-
285-
let window_func_name = window.func.func_name();
286-
let func = match &window.func {
287-
WindowFuncType::Aggregate(agg) => {
288-
// resolve aggregate function args in window function.
289-
let mut replaced_args: Vec<ScalarExpr> = Vec::with_capacity(agg.args.len());
290-
for (i, arg) in agg.args.iter().enumerate() {
291-
let name = format!("{}_arg_{}", &window_func_name, i);
292-
if let ScalarExpr::BoundColumnRef(column_ref) = arg {
293-
replaced_args.push(column_ref.clone().into());
294-
agg_args.push(ScalarItem {
295-
index: column_ref.column.index,
296-
scalar: arg.clone(),
297-
});
298-
} else {
299-
let index = self
300-
.metadata
301-
.write()
302-
.add_derived_column(name.clone(), arg.data_type()?);
303-
304-
// Generate a ColumnBinding for each argument of aggregates
305-
let column_binding = ColumnBinding {
306-
database_name: None,
307-
table_name: None,
308-
column_name: name,
309-
index,
310-
data_type: Box::new(arg.data_type()?),
311-
visibility: Visibility::Visible,
312-
};
313-
replaced_args.push(
314-
BoundColumnRef {
315-
span: arg.span(),
316-
column: column_binding.clone(),
317-
}
318-
.into(),
319-
);
320-
agg_args.push(ScalarItem {
321-
index,
322-
scalar: arg.clone(),
323-
});
324-
}
325-
}
326-
WindowFuncType::Aggregate(AggregateFunction {
327-
display_name: agg.display_name.clone(),
328-
func_name: agg.func_name.clone(),
329-
distinct: agg.distinct,
330-
params: agg.params.clone(),
331-
args: replaced_args,
332-
return_type: agg.return_type.clone(),
333-
})
334-
}
335-
func => func.clone(),
336-
};
337-
338-
// resolve partition by
339-
let mut partition_by_items = vec![];
340-
for (i, part) in window.partition_by.iter().enumerate() {
341-
let name = format!("{}_part_{}", &window_func_name, i);
342-
if let ScalarExpr::BoundColumnRef(column_ref) = part {
343-
replaced_partition_items.push(column_ref.clone().into());
344-
partition_by_items.push(ScalarItem {
345-
index: column_ref.column.index,
346-
scalar: part.clone(),
347-
});
348-
} else {
349-
let index = self
350-
.metadata
351-
.write()
352-
.add_derived_column(name.clone(), part.data_type()?);
353-
354-
// Generate a ColumnBinding for each argument of aggregates
355-
let column_binding = ColumnBinding {
356-
database_name: None,
357-
table_name: None,
358-
column_name: name,
359-
index,
360-
data_type: Box::new(part.data_type()?),
361-
visibility: Visibility::Visible,
362-
};
363-
replaced_partition_items.push(
364-
BoundColumnRef {
365-
span: part.span(),
366-
column: column_binding.clone(),
367-
}
368-
.into(),
369-
);
370-
partition_by_items.push(ScalarItem {
371-
index,
372-
scalar: part.clone(),
373-
});
374-
}
375-
}
376-
377-
// resolve order by
378-
let mut order_by_items = vec![];
379-
for (i, order) in window.order_by.iter().enumerate() {
380-
let name = format!("{}_order_{}", &window_func_name, i);
381-
if let ScalarExpr::BoundColumnRef(column_ref) = &order.expr {
382-
replaced_order_by_items.push(WindowOrderBy {
383-
expr: column_ref.clone().into(),
384-
asc: order.asc,
385-
nulls_first: order.nulls_first,
386-
});
387-
order_by_items.push(WindowOrderByInfo {
388-
order_by_item: ScalarItem {
389-
index: column_ref.column.index,
390-
scalar: order.expr.clone(),
391-
},
392-
asc: order.asc,
393-
nulls_first: order.nulls_first,
394-
})
395-
} else {
396-
let index = self
397-
.metadata
398-
.write()
399-
.add_derived_column(name.clone(), order.expr.data_type()?);
400-
401-
// Generate a ColumnBinding for each argument of aggregates
402-
let column_binding = ColumnBinding {
403-
database_name: None,
404-
table_name: None,
405-
column_name: name,
406-
index,
407-
data_type: Box::new(order.expr.data_type()?),
408-
visibility: Visibility::Visible,
409-
};
410-
replaced_order_by_items.push(WindowOrderBy {
411-
expr: BoundColumnRef {
412-
span: order.expr.span(),
413-
column: column_binding,
414-
}
415-
.into(),
416-
asc: order.asc,
417-
nulls_first: order.nulls_first,
418-
});
419-
order_by_items.push(WindowOrderByInfo {
420-
order_by_item: ScalarItem {
421-
index,
422-
scalar: order.expr.clone(),
423-
},
424-
asc: order.asc,
425-
nulls_first: order.nulls_first,
426-
})
427-
}
428-
}
429-
430-
let index = self
431-
.metadata
432-
.write()
433-
.add_derived_column(window.display_name.clone(), window.func.return_type());
434-
435-
// create window info
436-
let window_info = WindowFunctionInfo {
437-
index,
438-
func: func.clone(),
439-
arguments: agg_args,
440-
partition_by_items,
441-
order_by_items,
442-
frame: window.frame.clone(),
443-
};
444-
445-
// push window info to BindContext
446-
window_infos.window_functions.push(window_info);
447-
window_infos.window_functions_map.insert(
448-
window.display_name.clone(),
449-
window_infos.window_functions.len() - 1,
450-
);
451-
452-
let replaced_window = WindowFunc {
453-
display_name: window.display_name.clone(),
454-
func,
455-
partition_by: replaced_partition_items,
456-
order_by: replaced_order_by_items,
457-
frame: window.frame.clone(),
458-
};
459-
460-
Ok(replaced_window.into())
461-
}
462324
}
463-
464325
impl Binder {
465-
/// Analyze aggregates and windows in select clause, this will rewrite aggregate and window functions.
466-
/// See `AggregateRewriter` for more details.
467-
pub(crate) fn analyze_aggregate_and_window_select(
326+
/// Analyze aggregates in select clause, this will rewrite aggregate functions.
327+
/// See [`AggregateRewriter`] for more details.
328+
pub(crate) fn analyze_aggregate_select(
468329
&mut self,
469330
bind_context: &mut BindContext,
470331
select_list: &mut SelectList,
471332
) -> Result<()> {
472333
for item in select_list.items.iter_mut() {
473-
let mut rewriter = AggregateAndWindowRewriter::new(bind_context, self.metadata.clone());
334+
let mut rewriter = AggregateRewriter::new(bind_context, self.metadata.clone());
474335
let new_scalar = rewriter.visit(&item.scalar)?;
475336
item.scalar = new_scalar;
476337
}

0 commit comments

Comments
 (0)