@@ -687,8 +687,6 @@ impl Stream for GroupedHashAggregateStream {
                             // Do the grouping
                             self.group_aggregate_batch(batch)?;

-                            self.update_skip_aggregation_probe(input_rows);
-
                             // If we can begin emitting rows, do so,
                             // otherwise keep consuming input
                             assert!(!self.input_done);
@@ -712,9 +710,22 @@ impl Stream for GroupedHashAggregateStream {
                                 break 'reading_input;
                             }

-                            self.emit_early_if_necessary()?;
+                            // Check if we should switch to skip aggregation mode.
+                            // It's important that we do this before we emit early,
+                            // since we've already updated the probe.
+                            self.update_skip_aggregation_probe(input_rows);
+                            if let Some(new_state) = self.switch_to_skip_aggregation()? {
+                                timer.done();
+                                self.exec_state = new_state;
+                                break 'reading_input;
+                            }

-                            self.switch_to_skip_aggregation()?;
+                            // Check if we need to emit early due to memory pressure
+                            if let Some(new_state) = self.emit_early_if_necessary()? {
+                                timer.done();
+                                self.exec_state = new_state;
+                                break 'reading_input;
+                            }

                             timer.done();
                         }
@@ -788,6 +799,15 @@ impl Stream for GroupedHashAggregateStream {
                         }
                         None => {
                             // inner is done, switching to `Done` state
+                            // Sanity check: when switching from SkippingAggregation to Done,
+                            // all groups should have already been emitted
+                            if !self.group_values.is_empty() {
+                                return Poll::Ready(Some(internal_err!(
+                                    "Switching from SkippingAggregation to Done with {} groups still in hash table. \
+                                    This is a bug - all groups should have been emitted before skip aggregation started.",
+                                    self.group_values.len()
+                                )));
+                            }
                             self.exec_state = ExecutionState::Done;
                         }
                     }
@@ -835,6 +855,13 @@ impl Stream for GroupedHashAggregateStream {
                 }

                 ExecutionState::Done => {
+                    // Sanity check: all groups should have been emitted by now
+                    if !self.group_values.is_empty() {
+                        return Poll::Ready(Some(internal_err!(
+                            "AggregateStream was in Done state with {} groups left in hash table. \
+                            This is a bug - all groups should have been emitted before entering Done state.",
+                            self.group_values.len())));
+                    }
                     // release the memory reservation since sending back output batch itself needs
                     // some memory reservation, so make some room for it.
                     self.clear_all();
@@ -1100,18 +1127,20 @@ impl GroupedHashAggregateStream {
     /// Emit if the used memory exceeds the target for partial aggregation.
     /// Currently only [`GroupOrdering::None`] is supported for early emitting.
     /// TODO: support group_ordering for early emitting
-    fn emit_early_if_necessary(&mut self) -> Result<()> {
+    ///
+    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
+    fn emit_early_if_necessary(&mut self) -> Result<Option<ExecutionState>> {
         if self.group_values.len() >= self.batch_size
             && matches!(self.group_ordering, GroupOrdering::None)
             && self.update_memory_reservation().is_err()
         {
             assert_eq!(self.mode, AggregateMode::Partial);
             let n = self.group_values.len() / self.batch_size * self.batch_size;
             if let Some(batch) = self.emit(EmitTo::First(n), false)? {
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                return Ok(Some(ExecutionState::ProducingOutput(batch)));
             };
         }
-        Ok(())
+        Ok(None)
     }

     /// At this point, all the inputs are read and there are some spills.
@@ -1194,16 +1223,18 @@ impl GroupedHashAggregateStream {
     /// skipped, forces stream to produce currently accumulated output.
     ///
     /// Notice: It should only be called in Partial aggregation
-    fn switch_to_skip_aggregation(&mut self) -> Result<()> {
+    ///
+    /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise.
+    fn switch_to_skip_aggregation(&mut self) -> Result<Option<ExecutionState>> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
                 if let Some(batch) = self.emit(EmitTo::All, false)? {
-                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                    return Ok(Some(ExecutionState::ProducingOutput(batch)));
                 };
             }
         }

-        Ok(())
+        Ok(None)
     }

     /// Returns true if the aggregation probe indicates that aggregation
@@ -1245,3 +1276,123 @@ impl GroupedHashAggregateStream {
         Ok(states_batch)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::TestMemoryExec;
+    use arrow::array::{Int32Array, Int64Array};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+    use datafusion_execution::TaskContext;
+    use datafusion_functions_aggregate::count::count_udaf;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::expressions::col;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_double_emission_race_condition_bug() -> Result<()> {
+        // Fix for https://github.com/apache/datafusion/issues/18701
+        // This test proves that the double-emission race condition is fixed:
+        // previously emit_early_if_necessary() and switch_to_skip_aggregation()
+        // could both emit in the same loop iteration, causing data loss.
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("group_col", DataType::Int32, false),
+            Field::new("value_col", DataType::Int64, false),
+        ]));
+
+        // Create data that will trigger BOTH conditions in the same iteration:
+        // 1. More groups than batch_size (triggers early emission when memory pressure hits)
+        // 2. High cardinality ratio (triggers skip aggregation)
+        let batch_size = 1024; // We'll set this in the session config
+        let num_groups = batch_size + 100; // Slightly more than batch_size (1124 groups)
+
+        // Create exactly 1 row per group = 100% cardinality ratio
+        let group_ids: Vec<i32> = (0..num_groups as i32).collect();
+        let values: Vec<i64> = vec![1; num_groups];
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(group_ids)),
+                Arc::new(Int64Array::from(values)),
+            ],
+        )?;
+
+        let input_partitions = vec![vec![batch]];
+
+        // Create constrained memory to trigger early emission but not completely fail
+        let runtime = RuntimeEnvBuilder::default()
+            .with_memory_limit(1024, 1.0) // small enough to start but will trigger pressure
+            .build_arc()?;
+
+        let mut task_ctx = TaskContext::default().with_runtime(runtime);
+
+        // Configure to trigger BOTH conditions:
+        // 1. Low probe threshold (triggers the skip probe after a few rows)
+        // 2. Low ratio threshold (triggers skip aggregation immediately)
+        // 3. batch_size of 1024 so our 1124 groups will trigger early emission
+        // This creates the race condition where both emit paths are triggered
+        let mut session_config = task_ctx.session_config().clone();
+        session_config = session_config.set(
+            "datafusion.execution.batch_size",
+            &datafusion_common::ScalarValue::UInt64(Some(1024)),
+        );
+        session_config = session_config.set(
+            "datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
+            &datafusion_common::ScalarValue::UInt64(Some(50)),
+        );
+        session_config = session_config.set(
+            "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
+            &datafusion_common::ScalarValue::Float64(Some(0.8)),
+        );
+        task_ctx = task_ctx.with_session_config(session_config);
+        let task_ctx = Arc::new(task_ctx);
+
+        // Create aggregate: COUNT(value_col) GROUP BY group_col
+        let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
+        let aggr_expr = vec![Arc::new(
+            AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
+                .schema(Arc::clone(&schema))
+                .alias("count_value")
+                .build()?,
+        )];
+
+        let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
+        let exec = Arc::new(TestMemoryExec::update_cache(Arc::new(exec)));
+
+        // Use Partial mode, where the race condition occurs
+        let aggregate_exec = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::new_single(group_expr),
+            aggr_expr,
+            vec![None],
+            exec,
+            Arc::clone(&schema),
+        )?;
+
+        // Execute and collect results
+        let mut stream =
+            GroupedHashAggregateStream::new(&aggregate_exec, Arc::clone(&task_ctx), 0)?;
+        let mut results = Vec::new();
+
+        while let Some(result) = stream.next().await {
+            let batch = result?;
+            results.push(batch);
+        }
+
+        // Count total groups emitted
+        let mut total_output_groups = 0;
+        for batch in &results {
+            total_output_groups += batch.num_rows();
+        }
+
+        assert_eq!(
+            total_output_groups, num_groups,
+            "Unexpected number of groups"
+        );
+
+        Ok(())
+    }
+}
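For reference, a minimal, self-contained sketch (hypothetical types and field names, not DataFusion's API) of the pattern this change adopts: the helpers return the next state instead of assigning `self.exec_state` themselves, so the read loop applies at most one transition, and therefore at most one emission, per iteration.

// Standalone sketch; the enum, struct, and fields below are illustrative only.
#[derive(Debug, PartialEq)]
enum ExecState {
    Reading,
    Producing(Vec<u64>), // stand-in for ExecutionState::ProducingOutput(batch)
}

struct SketchStream {
    exec_state: ExecState,
    pending_groups: Vec<u64>,
    memory_pressure: bool, // stand-in for "reservation update failed"
    should_skip: bool,     // stand-in for "probe says skip aggregation"
}

impl SketchStream {
    // Returns Some(next state) if early emission is needed, None otherwise.
    fn emit_early_if_necessary(&mut self) -> Option<ExecState> {
        if self.memory_pressure {
            return Some(ExecState::Producing(std::mem::take(&mut self.pending_groups)));
        }
        None
    }

    // Returns Some(next state) if skip aggregation should start, None otherwise.
    fn switch_to_skip_aggregation(&mut self) -> Option<ExecState> {
        if self.should_skip {
            return Some(ExecState::Producing(std::mem::take(&mut self.pending_groups)));
        }
        None
    }

    fn process_one_batch(&mut self) {
        // The skip-aggregation check runs first; whichever check fires, we leave the
        // loop body immediately, so a second emit can never overwrite the first.
        if let Some(next) = self.switch_to_skip_aggregation() {
            self.exec_state = next;
            return;
        }
        if let Some(next) = self.emit_early_if_necessary() {
            self.exec_state = next;
            return;
        }
    }
}

fn main() {
    let mut s = SketchStream {
        exec_state: ExecState::Reading,
        pending_groups: vec![1, 2, 3],
        memory_pressure: true, // both conditions true in the same iteration
        should_skip: true,
    };
    s.process_one_batch();
    // Exactly one emission happened and no groups were lost.
    assert_eq!(s.exec_state, ExecState::Producing(vec![1, 2, 3]));
    assert!(s.pending_groups.is_empty());
}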