Commit e6ddb48

Fix Partial AggregateExec correctness issue dropping rows (#18712)
## Which issue does this PR close?

More detail is in the issue.

- Closes #18701

## Rationale for this change

This is a major correctness issue.

## What changes are included in this PR?

Fixes the issue by reordering the skip-aggregation and early-emit checks within partial aggregate execution.

## Are these changes tested?

Yes. The unit test added here failed before this change.

## Are there any user-facing changes?
1 parent a7b2e85 commit e6ddb48
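
To make the failure mode concrete, below is a minimal, self-contained sketch (plain Rust, not DataFusion's actual code; `Groups` and `ExecState` are simplified stand-ins). Under the old ordering, `emit_early_if_necessary()` and `switch_to_skip_aggregation()` each assigned `self.exec_state = ExecutionState::ProducingOutput(batch)` directly, so when both fired in the same read-loop iteration the second assignment replaced the batch staged by the first and those rows were never produced. The fix has each helper return the proposed state, which the caller installs before breaking out of the loop, and checks skip aggregation before early emission.

```rust
// Minimal, self-contained sketch (not DataFusion's actual implementation) of the
// double emission this commit fixes. `Groups` stands in for the partial aggregation
// hash table and `ExecState` for the stream's execution state; real names differ.

#[derive(Debug)]
enum ExecState {
    ReadingInput,
    ProducingOutput(Vec<i64>), // stand-in for a staged RecordBatch
}

struct Groups(Vec<i64>);

impl Groups {
    // Analogue of emit(EmitTo::First(n)): early emission under memory pressure.
    fn emit_first(&mut self, n: usize) -> Vec<i64> {
        self.0.drain(..n).collect()
    }
    // Analogue of emit(EmitTo::All): emission when switching to skip aggregation.
    fn emit_all(&mut self) -> Vec<i64> {
        std::mem::take(&mut self.0)
    }
}

// Old shape: each check assigned exec_state directly, so the second assignment
// replaced (and lost) the batch staged by the first one in the same iteration.
#[allow(unused_assignments)]
fn buggy_iteration(groups: &mut Groups) -> ExecState {
    let mut state = ExecState::ReadingInput;
    state = ExecState::ProducingOutput(groups.emit_first(1024)); // early emit stages 1024 rows
    state = ExecState::ProducingOutput(groups.emit_all()); // overwrite: those 1024 rows are gone
    state
}

// Fixed shape: the skip-aggregation check runs first, the helper returns the proposed
// state, and the caller installs it and breaks out of the read loop, so at most one
// emission is staged per iteration.
fn fixed_iteration(groups: &mut Groups) -> ExecState {
    // switch_to_skip_aggregation() analogue: emit everything and stop reading input.
    ExecState::ProducingOutput(groups.emit_all())
}

fn main() {
    // Scenario from the regression test: 1124 groups, batch_size = 1024, and both
    // memory pressure and the skip-aggregation probe trip in the same iteration.
    let buggy = buggy_iteration(&mut Groups((0..1124).collect()));
    let fixed = fixed_iteration(&mut Groups((0..1124).collect()));

    if let ExecState::ProducingOutput(batch) = buggy {
        println!("buggy: {} of 1124 groups reach the output", batch.len()); // 100
    }
    if let ExecState::ProducingOutput(batch) = fixed {
        println!("fixed: {} of 1124 groups reach the output", batch.len()); // 1124
    }
}
```

Running the sketch prints 100 surviving groups for the buggy ordering and all 1124 for the fixed one, mirroring the assertion in the unit test added by this commit.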

File tree

1 file changed (+161, -10 lines)

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 161 additions & 10 deletions
@@ -687,8 +687,6 @@ impl Stream for GroupedHashAggregateStream {
                         // Do the grouping
                         self.group_aggregate_batch(batch)?;
 
-                        self.update_skip_aggregation_probe(input_rows);
-
                         // If we can begin emitting rows, do so,
                         // otherwise keep consuming input
                         assert!(!self.input_done);
@@ -712,9 +710,22 @@ impl Stream for GroupedHashAggregateStream {
                             break 'reading_input;
                         }
 
-                        self.emit_early_if_necessary()?;
+                        // Check if we should switch to skip aggregation mode
+                        // It's important that we do this before we early emit since we've
+                        // already updated the probe.
+                        self.update_skip_aggregation_probe(input_rows);
+                        if let Some(new_state) = self.switch_to_skip_aggregation()? {
+                            timer.done();
+                            self.exec_state = new_state;
+                            break 'reading_input;
+                        }
 
-                        self.switch_to_skip_aggregation()?;
+                        // Check if we need to emit early due to memory pressure
+                        if let Some(new_state) = self.emit_early_if_necessary()? {
+                            timer.done();
+                            self.exec_state = new_state;
+                            break 'reading_input;
+                        }
 
                         timer.done();
                     }
@@ -788,6 +799,15 @@ impl Stream for GroupedHashAggregateStream {
                         }
                         None => {
                             // inner is done, switching to `Done` state
+                            // Sanity check: when switching from SkippingAggregation to Done,
+                            // all groups should have already been emitted
+                            if !self.group_values.is_empty() {
+                                return Poll::Ready(Some(internal_err!(
+                                    "Switching from SkippingAggregation to Done with {} groups still in hash table. \
+                                    This is a bug - all groups should have been emitted before skip aggregation started.",
+                                    self.group_values.len()
+                                )));
+                            }
                             self.exec_state = ExecutionState::Done;
                         }
                     }
@@ -835,6 +855,13 @@ impl Stream for GroupedHashAggregateStream {
                 }
 
                 ExecutionState::Done => {
+                    // Sanity check: all groups should have been emitted by now
+                    if !self.group_values.is_empty() {
+                        return Poll::Ready(Some(internal_err!(
+                            "AggregateStream was in Done state with {} groups left in hash table. \
+                            This is a bug - all groups should have been emitted before entering Done state.",
+                            self.group_values.len())));
+                    }
                     // release the memory reservation since sending back output batch itself needs
                     // some memory reservation, so make some room for it.
                     self.clear_all();
@@ -1100,18 +1127,20 @@ impl GroupedHashAggregateStream {
     /// Emit if the used memory exceeds the target for partial aggregation.
     /// Currently only [`GroupOrdering::None`] is supported for early emitting.
     /// TODO: support group_ordering for early emitting
-    fn emit_early_if_necessary(&mut self) -> Result<()> {
+    ///
+    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
+    fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
         if self.group_values.len() >= self.batch_size
             && matches!(self.group_ordering, GroupOrdering::None)
             && self.update_memory_reservation().is_err()
         {
             assert_eq!(self.mode, AggregateMode::Partial);
             let n = self.group_values.len() / self.batch_size * self.batch_size;
             if let Some(batch) = self.emit(EmitTo::First(n), false)? {
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                return Ok(Some(ExecutionState::ProducingOutput(batch)));
             };
         }
-        Ok(())
+        Ok(None)
     }
 
     /// At this point, all the inputs are read and there are some spills.
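
A side note on the early-emit threshold in `emit_early_if_necessary` above: `let n = self.group_values.len() / self.batch_size * self.batch_size` rounds the group count down to a whole number of batches, so any remainder stays in the hash table for a later emission. A small illustrative check of that arithmetic (plain Rust, using the figures from the regression test added at the end of this diff):

```rust
// Illustration only: the rounding used by emit_early_if_necessary, applied to the
// figures from the regression test (1124 groups, batch_size = 1024).
fn main() {
    let batch_size = 1024usize;
    let group_count = batch_size + 100; // 1124 groups currently in the hash table

    // Integer division rounds down to a whole number of batches.
    let n = group_count / batch_size * batch_size;
    assert_eq!(n, 1024); // EmitTo::First(1024): emit one full batch early...

    let remaining = group_count - n;
    assert_eq!(remaining, 100); // ...and 100 groups stay behind. Under the old
                                // ordering the skip-aggregation path then emitted
                                // these and overwrote the staged batch of 1024.
    println!("emit {n} now, keep {remaining} for later");
}
```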
@@ -1194,16 +1223,18 @@ impl GroupedHashAggregateStream {
     /// skipped, forces stream to produce currently accumulated output.
     ///
     /// Notice: It should only be called in Partial aggregation
-    fn switch_to_skip_aggregation(&mut self) -> Result<()> {
+    ///
+    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
+    fn switch_to_skip_aggregation(&mut self) -> Result<Option<ExecutionState>> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
                 if let Some(batch) = self.emit(EmitTo::All, false)? {
-                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                    return Ok(Some(ExecutionState::ProducingOutput(batch)));
                 };
             }
         }
 
-        Ok(())
+        Ok(None)
     }
 
     /// Returns true if the aggregation probe indicates that aggregation
@@ -1245,3 +1276,123 @@ impl GroupedHashAggregateStream {
         Ok(states_batch)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::TestMemoryExec;
+    use arrow::array::{Int32Array, Int64Array};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+    use datafusion_execution::TaskContext;
+    use datafusion_functions_aggregate::count::count_udaf;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::expressions::col;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_double_emission_race_condition_bug() -> Result<()> {
+        // Fix for https://github.com/apache/datafusion/issues/18701
+        // This test specifically proves that we have fixed the double emission race condition
+        // where emit_early_if_necessary() and switch_to_skip_aggregation()
+        // both emit in the same loop iteration, causing data loss
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("group_col", DataType::Int32, false),
+            Field::new("value_col", DataType::Int64, false),
+        ]));
+
+        // Create data that will trigger BOTH conditions in the same iteration:
+        // 1. More groups than batch_size (triggers early emission when memory pressure hits)
+        // 2. High cardinality ratio (triggers skip aggregation)
+        let batch_size = 1024; // We'll set this in session config
+        let num_groups = batch_size + 100; // Slightly more than batch_size (1124 groups)
+
+        // Create exactly 1 row per group = 100% cardinality ratio
+        let group_ids: Vec<i32> = (0..num_groups as i32).collect();
+        let values: Vec<i64> = vec![1; num_groups];
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(group_ids)),
+                Arc::new(Int64Array::from(values)),
+            ],
+        )?;
+
+        let input_partitions = vec![vec![batch]];
+
+        // Create constrained memory to trigger early emission but not completely fail
+        let runtime = RuntimeEnvBuilder::default()
+            .with_memory_limit(1024, 1.0) // small enough to start but will trigger pressure
+            .build_arc()?;
+
+        let mut task_ctx = TaskContext::default().with_runtime(runtime);
+
+        // Configure to trigger BOTH conditions:
+        // 1. Low probe threshold (triggers skip probe after few rows)
+        // 2. Low ratio threshold (triggers skip aggregation immediately)
+        // 3. Set batch_size to 1024 so our 1124 groups will trigger early emission
+        // This creates the race condition where both emit paths are triggered
+        let mut session_config = task_ctx.session_config().clone();
+        session_config = session_config.set(
+            "datafusion.execution.batch_size",
+            &datafusion_common::ScalarValue::UInt64(Some(1024)),
+        );
+        session_config = session_config.set(
+            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
+            &datafusion_common::ScalarValue::UInt64(Some(50)),
+        );
+        session_config = session_config.set(
+            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
+            &datafusion_common::ScalarValue::Float64(Some(0.8)),
+        );
+        task_ctx = task_ctx.with_session_config(session_config);
+        let task_ctx = Arc::new(task_ctx);
+
+        // Create aggregate: COUNT(*) GROUP BY group_col
+        let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
+        let aggr_expr = vec![Arc::new(
+            AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("count_value")
+                .build()?,
+        )];
+
+        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
+        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
+
+        // Use Partial mode where the race condition occurs
+        let aggregate_exec = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::new_single(group_expr),
+            aggr_expr,
+            vec![None],
+            exec,
+            Arc::clone(&schema),
+        )?;
+
+        // Execute and collect results
+        let mut stream =
+            GroupedHashAggregateStream::new(&aggregate_exec, Arc::clone(&task_ctx), 0)?;
+        let mut results = Vec::new();
+
+        while let Some(result) = stream.next().await {
+            let batch = result?;
+            results.push(batch);
+        }
+
+        // Count total groups emitted
+        let mut total_output_groups = 0;
+        for batch in &results {
+            total_output_groups += batch.num_rows();
+        }
+
+        assert_eq!(
+            total_output_groups, num_groups,
+            "Unexpected number of groups",
+        );
+
+        Ok(())
+    }
+}
