diff --git a/datafusion/core/benches/data_utils/mod.rs b/datafusion/core/benches/data_utils/mod.rs index 630bc056600b4..e9d705d6638b1 100644 --- a/datafusion/core/benches/data_utils/mod.rs +++ b/datafusion/core/benches/data_utils/mod.rs @@ -26,9 +26,9 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion_common::DataFusionError; -use rand::prelude::IndexedRandom; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use rand::{prelude::IndexedRandom, seq::SliceRandom}; use rand_distr::Distribution; use rand_distr::{Normal, Pareto}; use std::fmt::Write; @@ -276,3 +276,43 @@ fn test_schema(use_view: bool) -> SchemaRef { ])) } } + +/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids +/// This ensures consistent results across benchmark runs +#[allow(dead_code)] +pub(crate) fn make_distinct_data( + partition_cnt: i32, + sample_cnt: i32, +) -> Result<(Arc, Vec>), DataFusionError> { + let mut rng = rand::rngs::SmallRng::from_seed([42; 32]); + let total_samples = partition_cnt as usize * sample_cnt as usize; + let mut ids = Vec::new(); + for i in 0..total_samples { + ids.push(i as i64); + } + ids.shuffle(&mut rng); + + let mut global_idx = 0; + let schema = test_distinct_schema(); + let mut partitions = vec![]; + for _ in 0..partition_cnt { + let mut id_builder = Int64Builder::new(); + + for _ in 0..sample_cnt { + let id = ids[global_idx]; + id_builder.append_value(id); + global_idx += 1; + } + + let id_col = Arc::new(id_builder.finish()); + let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?; + partitions.push(vec![batch]); + } + + Ok((schema, partitions)) +} + +/// Returns a Schema for distinct benchmarks with i64 trace_id +fn test_distinct_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])) +} diff --git a/datafusion/core/benches/topk_aggregate.rs b/datafusion/core/benches/topk_aggregate.rs index 7979efdec605e..fd4d762cda6c8 100644 --- a/datafusion/core/benches/topk_aggregate.rs +++ b/datafusion/core/benches/topk_aggregate.rs @@ -28,6 +28,8 @@ use std::hint::black_box; use std::sync::Arc; use tokio::runtime::Runtime; +use crate::data_utils::make_distinct_data; + const LIMIT: usize = 10; async fn create_context( @@ -50,6 +52,25 @@ async fn create_context( Ok(ctx) } +async fn create_context_distinct( + partition_cnt: i32, + sample_cnt: i32, + use_topk: bool, +) -> Result { + // Use deterministic data generation for DISTINCT queries to ensure consistent results + let (schema, parts) = make_distinct_data(partition_cnt, sample_cnt).unwrap(); + let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap()); + + // Create the DataFrame + let mut cfg = SessionConfig::new(); + let opts = cfg.options_mut(); + opts.optimizer.enable_topk_aggregation = use_topk; + let ctx = SessionContext::new_with_config(cfg); + let _ = ctx.register_table("traces", mem_table)?; + + Ok(ctx) +} + fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) { black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap(); } @@ -59,6 +80,17 @@ fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) { .unwrap(); } +fn run_distinct( + rt: &Runtime, + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) { + black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await })) + .unwrap(); +} + async fn aggregate( ctx: SessionContext, limit: usize, @@ -133,6 +165,84 @@ async fn aggregate_string( Ok(()) } +async fn aggregate_distinct( + ctx: SessionContext, + limit: usize, + use_topk: bool, + asc: bool, +) -> Result<()> { + let order_direction = if asc { "asc" } else { "desc" }; + let sql = format!( + "select id from traces group by id order by id {order_direction} limit {limit};" + ); + let df = ctx.sql(sql.as_str()).await?; + let plan = df.create_physical_plan().await?; + let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string(); + assert_eq!( + actual_phys_plan.contains(&format!("lim=[{limit}]")), + use_topk + ); + let batches = collect(plan, ctx.task_ctx()).await?; + assert_eq!(batches.len(), 1); + let batch = batches.first().unwrap(); + assert_eq!(batch.num_rows(), 10); + + let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase(); + + let expected_asc = r#" ++----+ +| id | ++----+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++----+ +"# + .trim(); + + let expected_desc = r#" ++---------+ +| id | ++---------+ +| 9999999 | +| 9999998 | +| 9999997 | +| 9999996 | +| 9999995 | +| 9999994 | +| 9999993 | +| 9999992 | +| 9999991 | +| 9999990 | ++---------+ +"# + .trim(); + + // Verify exact results match expected values + if asc { + assert_eq!( + actual.trim(), + expected_asc, + "Ascending DISTINCT results do not match expected values" + ); + } else { + assert_eq!( + actual.trim(), + expected_desc, + "Descending DISTINCT results do not match expected values" + ); + } + + Ok(()) +} + fn criterion_benchmark(c: &mut Criterion) { let rt = Runtime::new().unwrap(); let limit = LIMIT; @@ -253,6 +363,37 @@ fn criterion_benchmark(c: &mut Criterion) { .as_str(), |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)), ); + + // DISTINCT benchmarks + let ctx = rt.block_on(async { + create_context_distinct(partitions, samples, false) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [no TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, true)), + ); + + let ctx_topk = rt.block_on(async { + create_context_distinct(partitions, samples, true) + .await + .unwrap() + }); + c.bench_function( + format!("distinct {} rows desc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, false)), + ); + + c.bench_function( + format!("distinct {} rows asc [TopK]", partitions * samples).as_str(), + |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, true)), + ); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs index 380a47505ac2d..f19e830ea6ef3 100644 --- a/datafusion/core/tests/execution/coop.rs +++ b/datafusion/core/tests/execution/coop.rs @@ -24,7 +24,7 @@ use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion::physical_plan::execution_plan::Boundedness; use datafusion::prelude::SessionContext; @@ -233,6 +233,7 @@ async fn agg_grouped_topk_yields( #[values(false, true)] pretend_infinite: bool, ) -> Result<(), Box> { // build session + let session_ctx = SessionContext::new(); // set up a top-k aggregation @@ -260,7 +261,7 @@ async fn agg_grouped_topk_yields( inf.clone(), inf.schema(), )? - .with_limit(Some(100)), + .with_limit_options(Some(LimitOptions::new(100))), ); query_yields(aggr, session_ctx.task_ctx()).await diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs index 2fdfece2a86e7..9e63c341c92d9 100644 --- a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -37,7 +37,7 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion_physical_plan::displayable; use datafusion_physical_plan::repartition::RepartitionExec; @@ -260,7 +260,7 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { schema, ) .unwrap() - .with_limit(Some(5)), + .with_limit_options(Some(LimitOptions::new(5))), ); let plan: Arc = final_agg; // should combine the Partial/Final AggregateExecs to a Single AggregateExec diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 782e0754b7d27..6d8e7995c18c2 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -98,7 +98,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { Arc::clone(input_agg_exec.input()), input_agg_exec.input_schema(), ) - .map(|combined_agg| combined_agg.with_limit(agg_exec.limit())) + .map(|combined_agg| { + combined_agg.with_limit_options(agg_exec.limit_options()) + }) .ok() .map(Arc::new) } else { diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 671d247cf36a5..fe9636f67619b 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -20,7 +20,7 @@ use std::sync::Arc; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -63,7 +63,7 @@ impl LimitedDistinctAggregation { aggr.input_schema(), ) .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); + .with_limit_options(Some(LimitOptions::new(limit))); Some(Arc::new(new_aggr)) } diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 7b2983ee71996..cec6bd70a2089 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::aggregates::LimitOptions; use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported}; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; @@ -47,28 +48,47 @@ impl TopKAggregation { order_desc: bool, limit: usize, ) -> Option> { - // ensure the sort direction matches aggregate function - let (field, desc) = aggr.get_minmax_desc()?; - if desc != order_desc { - return None; - } - let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; - let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; - let vt = field.data_type(); - if !topk_types_supported(&kt, vt) { + // Current only support single group key + let (group_key, group_key_alias) = + aggr.group_expr().expr().iter().exactly_one().ok()?; + let kt = group_key.data_type(&aggr.input().schema()).ok()?; + let vt = if let Some((field, _)) = aggr.get_minmax_desc() { + field.data_type().clone() + } else { + kt.clone() + }; + if !topk_types_supported(&kt, &vt) { return None; } if aggr.filter_expr().iter().any(|e| e.is_some()) { return None; } - // ensure the sort is on the same field as the aggregate output - if order_by != field.name() { + // Check if this is ordering by an aggregate function (MIN/MAX) + if let Some((field, desc)) = aggr.get_minmax_desc() { + // ensure the sort direction matches aggregate function + if desc != order_desc { + return None; + } + // ensure the sort is on the same field as the aggregate output + if order_by != field.name() { + return None; + } + } else if aggr.aggr_expr().is_empty() { + // This is a GROUP BY without aggregates, check if ordering is on the group key itself + if order_by != group_key_alias { + return None; + } + } else { + // Has aggregates but not MIN/MAX, or doesn't DISTINCT return None; } // We found what we want: clone, copy the limit down, and return modified node - let new_aggr = aggr.with_new_limit(Some(limit)); + let new_aggr = AggregateExec::with_new_limit_options( + aggr, + Some(LimitOptions::new_with_order(limit, order_desc)), + ); Some(Arc::new(new_aggr)) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 4dd9482ac4322..d645f5c55d434 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -502,6 +502,42 @@ enum DynamicFilterAggregateType { Max, } +/// Configuration for limit-based optimizations in aggregation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimitOptions { + /// The maximum number of rows to return + pub limit: usize, + /// Optional ordering direction (true = descending, false = ascending) + /// This is used for TopK aggregation to maintain a priority queue with the correct ordering + pub descending: Option, +} + +impl LimitOptions { + /// Create a new LimitOptions with a limit and no specific ordering + pub fn new(limit: usize) -> Self { + Self { + limit, + descending: None, + } + } + + /// Create a new LimitOptions with a limit and ordering direction + pub fn new_with_order(limit: usize, descending: bool) -> Self { + Self { + limit, + descending: Some(descending), + } + } + + pub fn limit(&self) -> usize { + self.limit + } + + pub fn descending(&self) -> Option { + self.descending + } +} + /// Hash aggregate execution plan #[derive(Debug, Clone)] pub struct AggregateExec { @@ -513,8 +549,8 @@ pub struct AggregateExec { aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, - /// Set if the output of this aggregation is truncated by a upstream sort/limit clause - limit: Option, + /// Configuration for limit-based optimizations + limit_options: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, /// Schema after the aggregate is applied @@ -558,7 +594,7 @@ impl AggregateExec { mode: self.mode, group_by: self.group_by.clone(), filter_expr: self.filter_expr.clone(), - limit: self.limit, + limit_options: self.limit_options, input: Arc::clone(&self.input), schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), @@ -567,9 +603,9 @@ impl AggregateExec { } /// Clone this exec, overriding only the limit hint. - pub fn with_new_limit(&self, limit: Option) -> Self { + pub fn with_new_limit_options(&self, limit_options: Option) -> Self { Self { - limit, + limit_options, // clone the rest of the fields required_input_ordering: self.required_input_ordering.clone(), metrics: ExecutionPlanMetricsSet::new(), @@ -709,7 +745,7 @@ impl AggregateExec { input_schema, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, - limit: None, + limit_options: None, input_order_mode, cache, dynamic_filter: None, @@ -725,11 +761,17 @@ impl AggregateExec { &self.mode } - /// Set the `limit` of this AggExec - pub fn with_limit(mut self, limit: Option) -> Self { - self.limit = limit; + /// Set the limit options for this AggExec + pub fn with_limit_options(mut self, limit_options: Option) -> Self { + self.limit_options = limit_options; self } + + /// Get the limit options (if set) + pub fn limit_options(&self) -> Option { + self.limit_options + } + /// Grouping expressions pub fn group_expr(&self) -> &PhysicalGroupBy { &self.group_by @@ -760,11 +802,6 @@ impl AggregateExec { Arc::clone(&self.input_schema) } - /// number of rows soft limit of the AggregateExec - pub fn limit(&self) -> Option { - self.limit - } - fn execute_typed( &self, partition: usize, @@ -777,11 +814,11 @@ impl AggregateExec { } // grouping by an expression that has a sort/limit upstream - if let Some(limit) = self.limit + if let Some(config) = self.limit_options && !self.is_unordered_unfiltered_group_by_distinct() { return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, + GroupedTopKAggregateStream::new(self, context, partition, config.limit)?, )); } @@ -802,6 +839,13 @@ impl AggregateExec { /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule /// on an AggregateExec. pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + if self + .limit_options() + .and_then(|config| config.descending) + .is_some() + { + return false; + } // ensure there is a group by if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() { return false; @@ -1119,8 +1163,8 @@ impl DisplayAs for AggregateExec { .map(|agg| agg.name().to_string()) .collect(); write!(f, ", aggr=[{}]", a.join(", "))?; - if let Some(limit) = self.limit { - write!(f, ", lim=[{limit}]")?; + if let Some(config) = self.limit_options { + write!(f, ", lim=[{}]", config.limit)?; } if self.input_order_mode != InputOrderMode::Linear { @@ -1179,6 +1223,9 @@ impl DisplayAs for AggregateExec { if !a.is_empty() { writeln!(f, "aggr={}", a.join(", "))?; } + if let Some(config) = self.limit_options { + writeln!(f, "limit={}", config.limit)?; + } } } Ok(()) @@ -1247,7 +1294,7 @@ impl ExecutionPlan for AggregateExec { Arc::clone(&self.input_schema), Arc::clone(&self.schema), )?; - me.limit = self.limit; + me.limit_options = self.limit_options; me.dynamic_filter = self.dynamic_filter.clone(); Ok(Arc::new(me)) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1ae7202711112..49ce125e739b3 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -679,7 +679,7 @@ impl GroupedHashAggregateStream { group_ordering, input_done: false, spill_state, - group_values_soft_limit: agg.limit, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), skip_aggregation_probe, reduction_factor, }) diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index a43b5cff12989..72c5d0c86745d 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -19,6 +19,7 @@ use crate::aggregates::group_values::GroupByMetrics; use crate::aggregates::topk::priority_map::PriorityMap; +#[cfg(debug_assertions)] use crate::aggregates::topk_types_supported; use crate::aggregates::{ AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, @@ -33,6 +34,7 @@ use datafusion_common::Result; use datafusion_common::internal_datafusion_err; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::metrics::RecordOutput; use futures::stream::{Stream, StreamExt}; use log::{Level, trace}; use std::pin::Pin; @@ -66,13 +68,27 @@ impl GroupedTopKAggregateStream { let group_by_metrics = GroupByMetrics::new(&aggr.metrics, partition); let aggregate_arguments = aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; - let (val_field, desc) = aggr - .get_minmax_desc() - .ok_or_else(|| internal_datafusion_err!("Min/max required"))?; let (expr, _) = &aggr.group_expr().expr()[0]; let kt = expr.data_type(&aggr.input().schema())?; - let vt = val_field.data_type().clone(); + + // Check if this is a MIN/MAX aggregate or a DISTINCT-like operation + let (vt, desc) = if let Some((val_field, desc)) = aggr.get_minmax_desc() { + // MIN/MAX case: use the aggregate output type + (val_field.data_type().clone(), desc) + } else { + // DISTINCT case: use the group key type and get ordering from limit_order_descending + // The ordering direction is set by the optimizer when it pushes down the limit + let desc = aggr + .limit_options() + .and_then(|config| config.descending) + .ok_or_else(|| { + internal_datafusion_err!( + "Ordering direction required for DISTINCT with limit" + ) + })?; + (kt.clone(), desc) + }; // Type validation is performed by the optimizer and can_use_topk() check. // This debug assertion documents the contract without runtime overhead in release builds. @@ -168,18 +184,21 @@ impl Stream for GroupedTopKAggregateStream { "Exactly 1 group value required" ); let group_by_values = Arc::clone(&group_by_values[0][0]); - let input_values = { - let _timer = (!self.aggregate_arguments.is_empty()).then(|| { - self.group_by_metrics.aggregate_arguments_time.timer() - }); - evaluate_many( + let input_values = if self.aggregate_arguments.is_empty() { + // DISTINCT case: use group key as both key and value + Arc::clone(&group_by_values) + } else { + // MIN/MAX case: evaluate aggregate expressions + let _timer = + self.group_by_metrics.aggregate_arguments_time.timer(); + let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), - )? + )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + Arc::clone(&input_values[0][0]) }; - assert_eq!(input_values.len(), 1, "Exactly 1 input required"); - assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); - let input_values = Arc::clone(&input_values[0][0]); // iterate over each column of group_by values (*self).intern(&group_by_values, &input_values)?; @@ -192,9 +211,15 @@ impl Stream for GroupedTopKAggregateStream { } let batch = { let _timer = emitting_time.timer(); - let cols = self.priority_map.emit()?; + let mut cols = self.priority_map.emit()?; + // For DISTINCT case (no aggregate expressions), only use the group key column + // since the schema only has one field and key/value are the same + if self.aggregate_arguments.is_empty() { + cols.truncate(1); + } RecordBatch::try_new(Arc::clone(&self.schema), cols)? }; + let batch = batch.record_output(&self.baseline_metrics); trace!( "partition {} emit batch with {} rows", self.partition, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd7dd3a6aff3c..b87f6272b35ad 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1219,6 +1219,8 @@ message MaybePhysicalSortExprs { message AggLimit { // wrap into a message to make it optional uint64 limit = 1; + // Optional ordering direction for TopK aggregation (true = descending, false = ascending) + optional bool descending = 2; } message AggregateExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e269606d163a3..c544d2c28b246 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -9,12 +9,18 @@ impl serde::Serialize for AggLimit { if self.limit != 0 { len += 1; } + if self.descending.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggLimit", len)?; if self.limit != 0 { #[allow(clippy::needless_borrow)] #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("limit", ToString::to_string(&self.limit).as_str())?; } + if let Some(v) = self.descending.as_ref() { + struct_ser.serialize_field("descending", v)?; + } struct_ser.end() } } @@ -26,11 +32,13 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { const FIELDS: &[&str] = &[ "limit", + "descending", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Limit, + Descending, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -53,6 +61,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { { match value { "limit" => Ok(GeneratedField::Limit), + "descending" => Ok(GeneratedField::Descending), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -73,6 +82,7 @@ impl<'de> serde::Deserialize<'de> for AggLimit { V: serde::de::MapAccess<'de>, { let mut limit__ = None; + let mut descending__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Limit => { @@ -83,10 +93,17 @@ impl<'de> serde::Deserialize<'de> for AggLimit { Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } + GeneratedField::Descending => { + if descending__.is_some() { + return Err(serde::de::Error::duplicate_field("descending")); + } + descending__ = map_.next_value()?; + } } } Ok(AggLimit { limit: limit__.unwrap_or_default(), + descending: descending__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cf343e0258d0b..fa8ea54624399 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1830,6 +1830,9 @@ pub struct AggLimit { /// wrap into a message to make it optional #[prost(uint64, tag = "1")] pub limit: u64, + /// Optional ordering direction for TopK aggregation (true = descending, false = ascending) + #[prost(bool, optional, tag = "2")] + pub descending: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateExecNode { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 0666fc2979b38..acd9af20baeaa 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -71,8 +71,8 @@ use datafusion_functions_table::generate_series::{ use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; -use datafusion_physical_plan::aggregates::AggregateMode; use datafusion_physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; +use datafusion_physical_plan::aggregates::{AggregateMode, LimitOptions}; use datafusion_physical_plan::analyze::AnalyzeExec; #[expect(deprecated)] use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -1105,11 +1105,6 @@ impl protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let limit = hash_agg - .limit - .as_ref() - .map(|lit_value| lit_value.limit as usize); - let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups, has_grouping_set), @@ -1119,7 +1114,16 @@ impl protobuf::PhysicalPlanNode { physical_schema, )?; - let agg = agg.with_limit(limit); + let agg = if let Some(limit_proto) = &hash_agg.limit { + let limit = limit_proto.limit as usize; + let limit_options = match limit_proto.descending { + Some(descending) => LimitOptions::new_with_order(limit, descending), + None => LimitOptions::new(limit), + }; + agg.with_limit_options(Some(limit_options)) + } else { + agg + }; Ok(Arc::new(agg)) } @@ -2527,8 +2531,9 @@ impl protobuf::PhysicalPlanNode { .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) .collect::>>()?; - let limit = exec.limit().map(|value| protobuf::AggLimit { - limit: value as u64, + let limit = exec.limit_options().map(|config| protobuf::AggLimit { + limit: config.limit() as u64, + descending: config.descending(), }); Ok(protobuf::PhysicalPlanNode { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 57421fd1f25e6..bcdfe8d3b7323 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -70,7 +70,7 @@ use datafusion::physical_expr::{ LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -615,7 +615,7 @@ fn roundtrip_aggregate_with_limit() -> Result<()> { Arc::new(EmptyExec::new(schema.clone())), schema, )?; - let agg = agg.with_limit(Some(12)); + let agg = agg.with_limit_options(Some(LimitOptions::new_with_order(12, false))); roundtrip_test(Arc::new(agg)) } diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 05f3e02bbc1b3..19ead8965ed01 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -344,5 +344,123 @@ physical_plan 06)----------DataSourceExec: partitions=1, partition_sizes=[1] +## Test GROUP BY with ORDER BY on the same column (no aggregate functions) +statement ok +CREATE TABLE ids(id int, value int) AS VALUES +(1, 10), +(2, 20), +(3, 30), +(4, 40), +(1, 50), +(2, 60), +(5, 70); + +query TT +explain select id from ids group by id order by id desc limit 3; +---- +logical_plan +01)Sort: ids.id DESC NULLS FIRST, fetch=3 +02)--Aggregate: groupBy=[[ids.id]], aggr=[[]] +03)----TableScan: ids projection=[id] +physical_plan +01)SortPreservingMergeExec: [id@0 DESC], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[id@0 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[], lim=[3] +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], lim=[3] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select id from ids group by id order by id desc limit 3; +---- +5 +4 +3 + +query TT +explain select id from ids group by id order by id asc limit 2; +---- +logical_plan +01)Sort: ids.id ASC NULLS LAST, fetch=2 +02)--Aggregate: groupBy=[[ids.id]], aggr=[[]] +03)----TableScan: ids projection=[id] +physical_plan +01)SortPreservingMergeExec: [id@0 ASC NULLS LAST], fetch=2 +02)--SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[], lim=[2] +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[], lim=[2] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select id from ids group by id order by id asc limit 2; +---- +1 +2 + +# Test with larger limit than distinct values +query I +select id from ids group by id order by id desc limit 100; +---- +5 +4 +3 +2 +1 + +# Test with bigint group by +statement ok +CREATE TABLE values_table (value INT, category BIGINT) AS VALUES +(10, 100), +(20, 200), +(30, 300), +(40, 400), +(50, 500), +(20, 200), +(10, 100), +(40, 400); + +query TT +explain select category from values_table group by category order by category desc limit 3; +---- +logical_plan +01)Sort: values_table.category DESC NULLS FIRST, fetch=3 +02)--Aggregate: groupBy=[[values_table.category]], aggr=[[]] +03)----TableScan: values_table projection=[category] +physical_plan +01)SortPreservingMergeExec: [category@0 DESC], fetch=3 +02)--SortExec: TopK(fetch=3), expr=[category@0 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[], lim=[3] +04)------RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[category@0 as category], aggr=[], lim=[3] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +select category from values_table group by category order by category desc limit 3; +---- +500 +400 +300 + +# Test with integer group by +query I +select value from values_table group by value order by value asc limit 3; +---- +10 +20 +30 + +# Test DISTINCT semantics are preserved +query I +select count(*) from (select category from values_table group by category order by category desc limit 3); +---- +3 + +statement ok +drop table values_table; + +statement ok +drop table ids; + statement ok drop table traces; diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index cd1ed2bc0caca..db4ec83f10129 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4329,9 +4329,9 @@ physical_plan 01)SortPreservingMergeExec: [months@0 DESC], fetch=5 02)--SortExec: TopK(fetch=5), expr=[months@0 DESC], preserve_partitioning=[true] 03)----ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] -04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[], lim=[5] 05)--------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 -06)----------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +06)----------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[], lim=[5] 07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1, maintains_sort_order=true 08)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], file_type=csv, has_header=false