Skip to content

Commit 4e437a0

Browse files
committed
WIP: Add input requirement to topk node
1 parent c8edccd commit 4e437a0

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

rust/cubestore/cubestore/src/queryplanner/topk/execute.rs

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use datafusion::error::DataFusionError;
1111

1212
use datafusion::execution::TaskContext;
1313
use datafusion::logical_expr::Accumulator;
14-
use datafusion::physical_expr::EquivalenceProperties;
14+
use datafusion::physical_expr::{EquivalenceProperties, LexRequirement, PhysicalSortRequirement};
1515
use datafusion::physical_plan::aggregates::{create_accumulators, AccumulatorItem, AggregateMode};
1616
use datafusion::physical_plan::common::collect;
1717
use datafusion::physical_plan::filter::FilterExec;
@@ -56,7 +56,8 @@ pub struct AggregateTopKExec {
5656
/// Always an instance of ClusterSendExec or WorkerExec.
5757
pub cluster: Arc<dyn ExecutionPlan>,
5858
pub schema: SchemaRef,
59-
cache: PlanProperties,
59+
pub cache: PlanProperties,
60+
pub sort_requirement: LexRequirement,
6061
}
6162

6263
/// Third item is the neutral value for the corresponding aggregate function.
@@ -72,6 +73,8 @@ impl AggregateTopKExec {
7273
having: Option<Arc<dyn PhysicalExpr>>,
7374
cluster: Arc<dyn ExecutionPlan>,
7475
schema: SchemaRef,
76+
// sort_requirement is passed in by topk_plan mostly for the sake of code deduplication
77+
sort_requirement: LexRequirement,
7578
) -> AggregateTopKExec {
7679
assert_eq!(schema.fields().len(), agg_expr.len() + key_len);
7780
assert_eq!(agg_fun.len(), agg_expr.len());
@@ -95,6 +98,7 @@ impl AggregateTopKExec {
9598
cluster,
9699
schema,
97100
cache,
101+
sort_requirement,
98102
}
99103
}
100104

@@ -171,13 +175,20 @@ impl ExecutionPlan for AggregateTopKExec {
171175
cluster,
172176
schema: self.schema.clone(),
173177
cache: self.cache.clone(),
178+
sort_requirement: self.sort_requirement.clone(),
174179
}))
175180
}
176181

177182
fn properties(&self) -> &PlanProperties {
178183
&self.cache
179184
}
180185

186+
// TODO upgrade DF: Probably should include output ordering in the PlanProperties.
187+
188+
fn required_input_ordering(&self) -> Vec<Option<LexRequirement>> {
189+
vec![Some(self.sort_requirement.clone())]
190+
}
191+
181192
#[tracing::instrument(level = "trace", skip(self))]
182193
fn execute(
183194
&self,
@@ -996,6 +1007,7 @@ fn finalize_aggregation_into(
9961007
#[cfg(test)]
9971008
mod tests {
9981009
use super::*;
1010+
use crate::queryplanner::topk::plan::make_sort_expr;
9991011
use crate::queryplanner::topk::{AggregateTopKExec, SortColumn};
10001012
use datafusion::arrow::array::{Array, ArrayRef, Int64Array};
10011013
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -1418,17 +1430,21 @@ mod tests {
14181430
// config: ExecutionConfig::new(),
14191431
// execution_props: ExecutionProps::new(),
14201432
// };
1421-
let agg_exprs = aggs
1433+
let agg_functions = aggs
14221434
.iter()
14231435
.enumerate()
1424-
.map(|(i, f)| Expr::AggregateFunction(AggregateFunction {
1436+
.map(|(i, f)| AggregateFunction {
14251437
func: topk_fun_to_fusion_type(&ctx, f).unwrap(),
14261438
args: vec![Expr::Column(Column::from_name(format!("agg{}", i + 1)))],
14271439
distinct: false,
14281440
filter: None,
14291441
order_by: None,
14301442
null_treatment: None,
1431-
}));
1443+
})
1444+
.collect::<Vec<_>>();
1445+
let agg_exprs = agg_functions.iter().map(|agg_fn|
1446+
Expr::AggregateFunction(agg_fn.clone())
1447+
);
14321448
let physical_agg_exprs: Vec<(AggregateFunctionExpr, Option<Arc<dyn PhysicalExpr>>, Option<Vec<datafusion::physical_expr::PhysicalSortExpr>>)> = agg_exprs
14331449
.map(|e| {
14341450
Ok(create_aggregate_expr_and_maybe_filter(
@@ -1439,7 +1455,7 @@ mod tests {
14391455
)?)
14401456
})
14411457
.collect::<Result<Vec<_>, DataFusionError>>()?;
1442-
let (agg_fn_exprs, agg_phys_exprs, _order_by): (Vec<_>, Vec<_>, Vec<_>) = itertools::multiunzip(physical_agg_exprs);
1458+
let (agg_fn_exprs, _agg_phys_exprs, _order_by): (Vec<_>, Vec<_>, Vec<_>) = itertools::multiunzip(physical_agg_exprs);
14431459

14441460
let output_agg_fields = agg_fn_exprs
14451461
.iter()
@@ -1453,6 +1469,23 @@ mod tests {
14531469
.collect::<Vec<_>>(),
14541470
));
14551471

1472+
let sort_requirement = order_by.iter().map(|c| {
1473+
let i = key_len + c.agg_index;
1474+
PhysicalSortRequirement {
1475+
expr: make_sort_expr(
1476+
&input_schema.inner(),
1477+
&aggs[c.agg_index],
1478+
Arc::new(datafusion::physical_expr::expressions::Column::new(input_schema.field(i).name(), i)),
1479+
&agg_functions[c.agg_index].args,
1480+
&input_schema,
1481+
),
1482+
options: Some(SortOptions {
1483+
descending: !c.asc,
1484+
nulls_first: c.nulls_first,
1485+
}),
1486+
}
1487+
}).collect();
1488+
14561489
Ok(AggregateTopKExec::new(
14571490
limit,
14581491
key_len,
@@ -1462,6 +1495,7 @@ mod tests {
14621495
None,
14631496
Arc::new(EmptyExec::new(input_schema.inner().clone())),
14641497
output_schema,
1498+
sort_requirement,
14651499
))
14661500
}
14671501

rust/cubestore/cubestore/src/queryplanner/topk/plan.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use datafusion::error::DataFusionError;
1212
use datafusion::execution::SessionState;
1313
use datafusion::logical_expr::expr::{AggregateFunction, Alias, ScalarFunction};
1414
use datafusion::logical_expr::expr::physical_name;
15+
use datafusion::physical_expr::PhysicalSortRequirement;
1516
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
1617
use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
1718
use datafusion::physical_plan::sorts::sort::SortExec;
@@ -432,6 +433,7 @@ pub fn plan_topk(
432433
}
433434
})
434435
.collect_vec();
436+
let sort_requirement = sort_expr.iter().map(|e| PhysicalSortRequirement::from(e.clone())).collect::<Vec<_>>();
435437
let sort = Arc::new(SortExec::new(sort_expr, aggregate));
436438
let sort_schema = sort.schema();
437439

@@ -461,11 +463,12 @@ pub fn plan_topk(
461463
having,
462464
cluster,
463465
schema,
466+
sort_requirement,
464467
));
465468
Ok(topk_exec)
466469
}
467470

468-
fn make_sort_expr(
471+
pub fn make_sort_expr(
469472
schema: &Arc<Schema>,
470473
fun: &TopKAggregateFunction,
471474
col: Arc<dyn PhysicalExpr>,

0 commit comments

Comments
 (0)