@@ -290,7 +290,6 @@ struct ShuffleRepartitioner {
     buffered_partitions: Vec<PartitionBuffer>,
     /// Partitioning scheme to use
     partitioning: Partitioning,
-    num_output_partitions: usize,
     runtime: Arc<RuntimeEnv>,
     metrics: ShuffleRepartitionerMetrics,
     /// Hashes for each row in the current batch
@@ -315,8 +314,6 @@ impl ShuffleRepartitioner {
         codec: CompressionCodec,
         enable_fast_encoding: bool,
     ) -> Result<Self> {
-        let num_output_partitions = partitioning.partition_count();
-
         let mut hashes_buf = Vec::with_capacity(batch_size);
         let mut partition_ids = Vec::with_capacity(batch_size);
 
@@ -336,7 +333,7 @@ impl ShuffleRepartitioner {
             output_data_file,
             output_index_file,
             schema: Arc::clone(&schema),
-            buffered_partitions: (0..num_output_partitions)
+            buffered_partitions: (0..partitioning.partition_count())
                 .map(|_| {
                     PartitionBuffer::try_new(
                         Arc::clone(&schema),
@@ -348,7 +345,6 @@ impl ShuffleRepartitioner {
                 })
                 .collect::<Result<Vec<_>>>()?,
             partitioning,
-            num_output_partitions,
             runtime,
             metrics,
             hashes_buf,
@@ -401,9 +397,21 @@ impl ShuffleRepartitioner {
         // number of rows those are written to output data file.
         self.metrics.baseline.record_output(input.num_rows());
 
-        let num_output_partitions = self.num_output_partitions;
         match &self.partitioning {
-            Partitioning::Hash(exprs, _) => {
+            any if any.partition_count() == 1 => {
+                let buffered_partitions = &mut self.buffered_partitions;
+
+                assert_eq!(buffered_partitions.len(), 1, "Expected 1 partition");
+
+                // TODO the single partition case could be optimized to avoid appending all
+                // rows from the batch into builders and then recreating the batch
+                // https://github.com/apache/datafusion-comet/issues/1453
+                let indices = (0..input.num_rows()).collect::<Vec<usize>>();
+
+                self.append_rows_to_partition(input.columns(), &indices, 0)
+                    .await?;
+            }
+            Partitioning::Hash(exprs, num_output_partitions) => {
                 let (partition_starts, shuffled_partition_ids): (Vec<usize>, Vec<usize>) = {
                     let mut timer = self.metrics.repart_time.timer();
 
@@ -423,11 +431,11 @@ impl ShuffleRepartitioner {
                         .iter()
                         .enumerate()
                         .for_each(|(idx, hash)| {
-                            partition_ids[idx] = pmod(*hash, num_output_partitions) as u64
+                            partition_ids[idx] = pmod(*hash, *num_output_partitions) as u64
                         });
 
                     // count each partition size
-                    let mut partition_counters = vec![0usize; num_output_partitions];
+                    let mut partition_counters = vec![0usize; *num_output_partitions];
                     partition_ids
                         .iter()
                         .for_each(|partition_id| partition_counters[*partition_id as usize] += 1);
@@ -478,24 +486,6 @@ impl ShuffleRepartitioner {
                         .await?;
                 }
             }
-            Partitioning::UnknownPartitioning(n) if *n == 1 => {
-                let buffered_partitions = &mut self.buffered_partitions;
-
-                assert_eq!(
-                    buffered_partitions.len(),
-                    1,
-                    "Expected 1 partition but got {}",
-                    buffered_partitions.len()
-                );
-
-                // TODO the single partition case could be optimized to avoid appending all
-                // rows from the batch into builders and then recreating the batch
-                // https://github.com/apache/datafusion-comet/issues/1453
-                let indices = (0..input.num_rows()).collect::<Vec<usize>>();
-
-                self.append_rows_to_partition(input.columns(), &indices, 0)
-                    .await?;
-            }
             other => {
                 // this should be unreachable as long as the validation logic
                 // in the constructor is kept up-to-date
@@ -511,8 +501,8 @@ impl ShuffleRepartitioner {
     /// Writes buffered shuffled record batches into Arrow IPC bytes.
     async fn shuffle_write(&mut self) -> Result<SendableRecordBatchStream> {
         let mut elapsed_compute = self.metrics.baseline.elapsed_compute().timer();
-        let num_output_partitions = self.num_output_partitions;
         let buffered_partitions = &mut self.buffered_partitions;
+        let num_output_partitions = buffered_partitions.len();
         let mut output_batches: Vec<Vec<u8>> = vec![vec![]; num_output_partitions];
         let mut offsets = vec![0; num_output_partitions + 1];
         for i in 0..num_output_partitions {