Skip to content

Commit 07e63ed

Browse files
authored
Fix TopK aggregation for UTF-8/Utf8View group keys and add safe fallback for unsupported string aggregates (#19285)
## Which issue does this PR close? * Closes #19219. ## Rationale for this change A `GROUP BY ... ORDER BY <aggregate> ... LIMIT` query can trigger DataFusion’s TopK aggregation optimization. In affected releases, queries grouping by text columns—especially `Utf8View` produced via SQL `varchar` mappings / `arrow_cast`—could fail at execution time with an error such as `Can't group type: Utf8View`. This happens because the optimizer may select the TopK aggregation path even when the underlying TopK data structures (heap/hash table) do not fully support the specific key/value Arrow types involved. Disabling `datafusion.optimizer.enable_topk_aggregation` is a workaround, but it forces users to trade correctness for performance. This PR makes TopK type support explicit and consistent across the optimizer and execution, adds support for UTF-8 string value heaps, and ensures unsupported key/value combinations fall back to the standard aggregation implementation rather than panicking. ## What changes are included in this PR? * **Centralized TopK type validation** * Introduced `topk_types_supported(key_type, value_type)` (in `physical-plan/src/aggregates/mod.rs`) to validate both grouping key and min/max value types. * Optimizer now uses this shared check rather than duplicating partial type logic. * **Safer AggregateExec cloning for limit pushdown** * Added `AggregateExec::with_new_limit` to clone an aggregate exec while overriding only the TopK `limit` hint, avoiding manual reconstruction and ensuring plan properties/fields remain consistent. * **TopK hash table improvements + helper functions** * Added `is_supported_hash_key_type` helper for grouping key compatibility checks. * Refactored string key extraction to a single helper function. * Added `find_or_insert` entry API to avoid double lookups and unify insertion behavior. * **TopK heap support for string aggregate values** * Added `StringHeap` implementation supporting `Utf8`, `LargeUtf8`, and `Utf8View` aggregate values using lexicographic ordering. * Added `is_supported_heap_type` helper for aggregate value compatibility. * Updated `new_heap` to create `StringHeap` for supported string types and return a clearer error message for unsupported types. * **Debug contract in TopK stream** * Added a debug assertion in `GroupedTopKAggregateStream` documenting that type validation should have already happened (optimizer + can_use_topk), without affecting release builds. ## Are these changes tested? Yes. * Added a new physical optimizer test covering UTF-8 grouping with: 1. **Supported** numeric `max/min` value (TopK should be used and results correct) 2. **Unsupported** string `max/min` value (must fall back to standard aggregation and not use `GroupedTopKAggregateStream`) * Added unit tests in `PriorityMap` to validate lexicographic `min/max` tracking for: * `Utf8` * `LargeUtf8` * `Utf8View` * Added SQLLogicTest coverage (`aggregates_topk.slt`) for: * `varchar` tables * `Utf8View` via `arrow_cast` * `EXPLAIN` verification that TopK limit propagation is applied and plans remain stable * Regression case for `max(trace_id)` with `ORDER BY ... LIMIT` ## Are there any user-facing changes? Yes (bug fix). * Queries that group by text columns (including `Utf8View`) and use `ORDER BY <aggregate> ... LIMIT` should no longer error. * TopK aggregation now supports UTF-8 string aggregate values for min/max (lexicographic ordering) where applicable. * For unsupported type combinations, DataFusion will fall back gracefully to the standard aggregation path instead of panicking. No breaking public API changes are intended. The only new public helper APIs are internal to the physical plan modules. ## LLM-generated code disclosure This PR includes LLM-generated code and comments. All LLM-generated content has been manually reviewed and tested.
1 parent b7091c0 commit 07e63ed

File tree

9 files changed

+698
-106
lines changed

9 files changed

+698
-106
lines changed

datafusion/core/benches/topk_aggregate.rs

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ use std::hint::black_box;
2828
use std::sync::Arc;
2929
use tokio::runtime::Runtime;
3030

31+
const LIMIT: usize = 10;
32+
3133
async fn create_context(
3234
partition_cnt: i32,
3335
sample_cnt: i32,
@@ -52,6 +54,11 @@ fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: boo
5254
black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap();
5355
}
5456

57+
fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) {
58+
black_box(rt.block_on(async { aggregate_string(ctx, limit, use_topk).await }))
59+
.unwrap();
60+
}
61+
5562
async fn aggregate(
5663
ctx: SessionContext,
5764
limit: usize,
@@ -72,7 +79,7 @@ async fn aggregate(
7279
let batches = collect(plan, ctx.task_ctx()).await?;
7380
assert_eq!(batches.len(), 1);
7481
let batch = batches.first().unwrap();
75-
assert_eq!(batch.num_rows(), 10);
82+
assert_eq!(batch.num_rows(), LIMIT);
7683

7784
let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase();
7885
let expected_asc = r#"
@@ -99,9 +106,36 @@ async fn aggregate(
99106
Ok(())
100107
}
101108

109+
/// Benchmark for string aggregate functions with topk optimization.
110+
/// This tests grouping by a numeric column (timestamp_ms) and aggregating
111+
/// a string column (trace_id) with Utf8 or Utf8View data types.
112+
async fn aggregate_string(
113+
ctx: SessionContext,
114+
limit: usize,
115+
use_topk: bool,
116+
) -> Result<()> {
117+
let sql = format!(
118+
"select max(trace_id) from traces group by timestamp_ms order by max(trace_id) desc limit {limit};"
119+
);
120+
let df = ctx.sql(sql.as_str()).await?;
121+
let plan = df.create_physical_plan().await?;
122+
let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string();
123+
assert_eq!(
124+
actual_phys_plan.contains(&format!("lim=[{limit}]")),
125+
use_topk
126+
);
127+
128+
let batches = collect(plan, ctx.task_ctx()).await?;
129+
assert_eq!(batches.len(), 1);
130+
let batch = batches.first().unwrap();
131+
assert_eq!(batch.num_rows(), LIMIT);
132+
133+
Ok(())
134+
}
135+
102136
fn criterion_benchmark(c: &mut Criterion) {
103137
let rt = Runtime::new().unwrap();
104-
let limit = 10;
138+
let limit = LIMIT;
105139
let partitions = 10;
106140
let samples = 1_000_000;
107141

@@ -170,6 +204,55 @@ fn criterion_benchmark(c: &mut Criterion) {
170204
.as_str(),
171205
|b| b.iter(|| run(&rt, ctx.clone(), limit, true, true)),
172206
);
207+
208+
// String aggregate benchmarks - grouping by timestamp, aggregating string column
209+
let ctx = rt
210+
.block_on(create_context(partitions, samples, false, true, false))
211+
.unwrap();
212+
c.bench_function(
213+
format!(
214+
"top k={limit} string aggregate {} time-series rows [Utf8]",
215+
partitions * samples
216+
)
217+
.as_str(),
218+
|b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)),
219+
);
220+
221+
let ctx = rt
222+
.block_on(create_context(partitions, samples, true, true, false))
223+
.unwrap();
224+
c.bench_function(
225+
format!(
226+
"top k={limit} string aggregate {} worst-case rows [Utf8]",
227+
partitions * samples
228+
)
229+
.as_str(),
230+
|b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)),
231+
);
232+
233+
let ctx = rt
234+
.block_on(create_context(partitions, samples, false, true, true))
235+
.unwrap();
236+
c.bench_function(
237+
format!(
238+
"top k={limit} string aggregate {} time-series rows [Utf8View]",
239+
partitions * samples
240+
)
241+
.as_str(),
242+
|b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)),
243+
);
244+
245+
let ctx = rt
246+
.block_on(create_context(partitions, samples, true, true, true))
247+
.unwrap();
248+
c.bench_function(
249+
format!(
250+
"top k={limit} string aggregate {} worst-case rows [Utf8View]",
251+
partitions * samples
252+
)
253+
.as_str(),
254+
|b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)),
255+
);
173256
}
174257

175258
criterion_group!(benches, criterion_benchmark);

datafusion/core/tests/physical_optimizer/aggregate_statistics.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@ use std::sync::Arc;
2020
use crate::physical_optimizer::test_utils::TestAggregate;
2121

2222
use arrow::array::Int32Array;
23+
use arrow::array::{Int64Array, StringArray};
2324
use arrow::datatypes::{DataType, Field, Schema};
2425
use arrow::record_batch::RecordBatch;
26+
use datafusion::datasource::memory::MemTable;
2527
use datafusion::datasource::memory::MemorySourceConfig;
2628
use datafusion::datasource::source::DataSourceExec;
29+
use datafusion::prelude::{SessionConfig, SessionContext};
2730
use datafusion_common::Result;
31+
use datafusion_common::assert_batches_eq;
2832
use datafusion_common::cast::as_int64_array;
2933
use datafusion_common::config::ConfigOptions;
3034
use datafusion_execution::TaskContext;
@@ -38,6 +42,7 @@ use datafusion_physical_plan::aggregates::AggregateMode;
3842
use datafusion_physical_plan::aggregates::PhysicalGroupBy;
3943
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
4044
use datafusion_physical_plan::common;
45+
use datafusion_physical_plan::displayable;
4146
use datafusion_physical_plan::filter::FilterExec;
4247
use datafusion_physical_plan::projection::ProjectionExec;
4348

@@ -316,3 +321,84 @@ async fn test_count_with_nulls_inexact_stat() -> Result<()> {
316321

317322
Ok(())
318323
}
324+
325+
/// Tests that TopK aggregation correctly handles UTF-8 (string) types in both grouping keys and aggregate values.
326+
///
327+
/// The TopK optimization is designed to efficiently handle `GROUP BY ... ORDER BY aggregate LIMIT n` queries
328+
/// by maintaining only the top K groups during aggregation. However, not all type combinations are supported.
329+
///
330+
/// This test verifies two scenarios:
331+
/// 1. **Supported case**: UTF-8 grouping key with numeric aggregate (max/min) - should use TopK optimization
332+
/// 2. **Unsupported case**: UTF-8 grouping key with UTF-8 aggregate value - must gracefully fall back to
333+
/// standard aggregation without panicking
334+
///
335+
/// The fallback behavior is critical because attempting to use TopK with unsupported types could cause
336+
/// runtime panics. This test ensures the optimizer correctly detects incompatible types and chooses
337+
/// the appropriate execution path.
338+
#[tokio::test]
339+
async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> {
340+
let mut config = SessionConfig::new();
341+
config.options_mut().optimizer.enable_topk_aggregation = true;
342+
let ctx = SessionContext::new_with_config(config);
343+
344+
let batch = RecordBatch::try_new(
345+
Arc::new(Schema::new(vec![
346+
Field::new("g", DataType::Utf8, false),
347+
Field::new("val_str", DataType::Utf8, false),
348+
Field::new("val_num", DataType::Int64, false),
349+
])),
350+
vec![
351+
Arc::new(StringArray::from(vec!["a", "b", "a"])),
352+
Arc::new(StringArray::from(vec!["alpha", "bravo", "charlie"])),
353+
Arc::new(Int64Array::from(vec![1, 2, 3])),
354+
],
355+
)?;
356+
let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
357+
ctx.register_table("t", Arc::new(table))?;
358+
359+
// Supported path: numeric min/max with UTF-8 grouping should still use TopK aggregation
360+
// and return correct results.
361+
let supported_df = ctx
362+
.sql("SELECT g, max(val_num) AS m FROM t GROUP BY g ORDER BY m DESC LIMIT 1")
363+
.await?;
364+
let supported_batches = supported_df.collect().await?;
365+
assert_batches_eq!(
366+
&[
367+
"+---+---+",
368+
"| g | m |",
369+
"+---+---+",
370+
"| a | 3 |",
371+
"+---+---+"
372+
],
373+
&supported_batches
374+
);
375+
376+
// Unsupported TopK value type: string min/max should fall back without panicking.
377+
let unsupported_df = ctx
378+
.sql("SELECT g, max(val_str) AS s FROM t GROUP BY g ORDER BY s DESC LIMIT 1")
379+
.await?;
380+
let unsupported_plan = unsupported_df.clone().create_physical_plan().await?;
381+
let unsupported_batches = unsupported_df.collect().await?;
382+
383+
// Ensure the plan avoided the TopK-specific stream implementation.
384+
let plan_display = displayable(unsupported_plan.as_ref())
385+
.indent(true)
386+
.to_string();
387+
assert!(
388+
!plan_display.contains("GroupedTopKAggregateStream"),
389+
"Unsupported UTF-8 aggregate value should not use TopK: {plan_display}"
390+
);
391+
392+
assert_batches_eq!(
393+
&[
394+
"+---+---------+",
395+
"| g | s |",
396+
"+---+---------+",
397+
"| a | charlie |",
398+
"+---+---------+"
399+
],
400+
&unsupported_batches
401+
);
402+
403+
Ok(())
404+
}

datafusion/physical-optimizer/src/topk_aggregation.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@
2020
use std::sync::Arc;
2121

2222
use crate::PhysicalOptimizerRule;
23-
use arrow::datatypes::DataType;
2423
use datafusion_common::Result;
2524
use datafusion_common::config::ConfigOptions;
2625
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
2726
use datafusion_physical_expr::expressions::Column;
2827
use datafusion_physical_plan::ExecutionPlan;
29-
use datafusion_physical_plan::aggregates::AggregateExec;
28+
use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported};
3029
use datafusion_physical_plan::execution_plan::CardinalityEffect;
3130
use datafusion_physical_plan::projection::ProjectionExec;
3231
use datafusion_physical_plan::sorts::sort::SortExec;
@@ -55,11 +54,8 @@ impl TopKAggregation {
5554
}
5655
let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
5756
let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
58-
if !kt.is_primitive()
59-
&& kt != DataType::Utf8
60-
&& kt != DataType::Utf8View
61-
&& kt != DataType::LargeUtf8
62-
{
57+
let vt = field.data_type();
58+
if !topk_types_supported(&kt, vt) {
6359
return None;
6460
}
6561
if aggr.filter_expr().iter().any(|e| e.is_some()) {
@@ -72,16 +68,7 @@ impl TopKAggregation {
7268
}
7369

7470
// We found what we want: clone, copy the limit down, and return modified node
75-
let new_aggr = AggregateExec::try_new(
76-
*aggr.mode(),
77-
aggr.group_expr().clone(),
78-
aggr.aggr_expr().to_vec(),
79-
aggr.filter_expr().to_vec(),
80-
Arc::clone(aggr.input()),
81-
aggr.input_schema(),
82-
)
83-
.expect("Unable to copy Aggregate!")
84-
.with_limit(Some(limit));
71+
let new_aggr = aggr.with_new_limit(Some(limit));
8572
Some(Arc::new(new_aggr))
8673
}
8774

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use parking_lot::Mutex;
4141
use std::collections::HashSet;
4242

4343
use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
44-
use arrow::datatypes::{Field, Schema, SchemaRef};
44+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
4545
use arrow::record_batch::RecordBatch;
4646
use arrow_schema::FieldRef;
4747
use datafusion_common::stats::Precision;
@@ -64,6 +64,8 @@ use datafusion_physical_expr_common::sort_expr::{
6464
use datafusion_expr::utils::AggregateOrderSensitivity;
6565
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
6666
use itertools::Itertools;
67+
use topk::hash_table::is_supported_hash_key_type;
68+
use topk::heap::is_supported_heap_type;
6769

6870
pub mod group_values;
6971
mod no_grouping;
@@ -72,6 +74,17 @@ mod row_hash;
7274
mod topk;
7375
mod topk_stream;
7476

77+
/// Returns true if TopK aggregation data structures support the provided key and value types.
78+
///
79+
/// This function checks whether both the key type (used for grouping) and value type
80+
/// (used in min/max aggregation) can be handled by the TopK aggregation heap and hash table.
81+
/// Supported types include Arrow primitives (integers, floats, decimals, intervals) and
82+
/// UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`).
83+
/// ```text
84+
pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool {
85+
is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type)
86+
}
87+
7588
/// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions.
7689
const AGGREGATION_HASH_SEED: ahash::RandomState =
7790
ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64);
@@ -553,6 +566,26 @@ impl AggregateExec {
553566
}
554567
}
555568

569+
/// Clone this exec, overriding only the limit hint.
570+
pub fn with_new_limit(&self, limit: Option<usize>) -> Self {
571+
Self {
572+
limit,
573+
// clone the rest of the fields
574+
required_input_ordering: self.required_input_ordering.clone(),
575+
metrics: ExecutionPlanMetricsSet::new(),
576+
input_order_mode: self.input_order_mode.clone(),
577+
cache: self.cache.clone(),
578+
mode: self.mode,
579+
group_by: self.group_by.clone(),
580+
aggr_expr: self.aggr_expr.clone(),
581+
filter_expr: self.filter_expr.clone(),
582+
input: Arc::clone(&self.input),
583+
schema: Arc::clone(&self.schema),
584+
input_schema: Arc::clone(&self.input_schema),
585+
dynamic_filter: self.dynamic_filter.clone(),
586+
}
587+
}
588+
556589
pub fn cache(&self) -> &PlanProperties {
557590
&self.cache
558591
}

0 commit comments

Comments
 (0)