Skip to content

Commit 9152dc2

Browse files
chore: Remove num partitions from repartitioner (apache#1498)
* chore: Remove num partitions from repartitioner. Treat all single-partition schemes the same way, and remove duplicate messaging by using assert_eq instead of assert.
* Rebased and resolved conflicts.

---------

Co-authored-by: Emily Matheys <[email protected]>
1 parent 15b1152 commit 9152dc2

File tree

1 file changed

+18
-28
lines changed

1 file changed

+18
-28
lines changed

native/core/src/execution/shuffle/shuffle_writer.rs

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ struct ShuffleRepartitioner {
290290
buffered_partitions: Vec<PartitionBuffer>,
291291
/// Partitioning scheme to use
292292
partitioning: Partitioning,
293-
num_output_partitions: usize,
294293
runtime: Arc<RuntimeEnv>,
295294
metrics: ShuffleRepartitionerMetrics,
296295
/// Hashes for each row in the current batch
@@ -315,8 +314,6 @@ impl ShuffleRepartitioner {
315314
codec: CompressionCodec,
316315
enable_fast_encoding: bool,
317316
) -> Result<Self> {
318-
let num_output_partitions = partitioning.partition_count();
319-
320317
let mut hashes_buf = Vec::with_capacity(batch_size);
321318
let mut partition_ids = Vec::with_capacity(batch_size);
322319

@@ -336,7 +333,7 @@ impl ShuffleRepartitioner {
336333
output_data_file,
337334
output_index_file,
338335
schema: Arc::clone(&schema),
339-
buffered_partitions: (0..num_output_partitions)
336+
buffered_partitions: (0..partitioning.partition_count())
340337
.map(|_| {
341338
PartitionBuffer::try_new(
342339
Arc::clone(&schema),
@@ -348,7 +345,6 @@ impl ShuffleRepartitioner {
348345
})
349346
.collect::<Result<Vec<_>>>()?,
350347
partitioning,
351-
num_output_partitions,
352348
runtime,
353349
metrics,
354350
hashes_buf,
@@ -401,9 +397,21 @@ impl ShuffleRepartitioner {
401397
// number of rows that are written to the output data file.
402398
self.metrics.baseline.record_output(input.num_rows());
403399

404-
let num_output_partitions = self.num_output_partitions;
405400
match &self.partitioning {
406-
Partitioning::Hash(exprs, _) => {
401+
any if any.partition_count() == 1 => {
402+
let buffered_partitions = &mut self.buffered_partitions;
403+
404+
assert_eq!(buffered_partitions.len(), 1, "Expected 1 partition");
405+
406+
// TODO the single partition case could be optimized to avoid appending all
407+
// rows from the batch into builders and then recreating the batch
408+
// https://github.com/apache/datafusion-comet/issues/1453
409+
let indices = (0..input.num_rows()).collect::<Vec<usize>>();
410+
411+
self.append_rows_to_partition(input.columns(), &indices, 0)
412+
.await?;
413+
}
414+
Partitioning::Hash(exprs, num_output_partitions) => {
407415
let (partition_starts, shuffled_partition_ids): (Vec<usize>, Vec<usize>) = {
408416
let mut timer = self.metrics.repart_time.timer();
409417

@@ -423,11 +431,11 @@ impl ShuffleRepartitioner {
423431
.iter()
424432
.enumerate()
425433
.for_each(|(idx, hash)| {
426-
partition_ids[idx] = pmod(*hash, num_output_partitions) as u64
434+
partition_ids[idx] = pmod(*hash, *num_output_partitions) as u64
427435
});
428436

429437
// count each partition size
430-
let mut partition_counters = vec![0usize; num_output_partitions];
438+
let mut partition_counters = vec![0usize; *num_output_partitions];
431439
partition_ids
432440
.iter()
433441
.for_each(|partition_id| partition_counters[*partition_id as usize] += 1);
@@ -478,24 +486,6 @@ impl ShuffleRepartitioner {
478486
.await?;
479487
}
480488
}
481-
Partitioning::UnknownPartitioning(n) if *n == 1 => {
482-
let buffered_partitions = &mut self.buffered_partitions;
483-
484-
assert_eq!(
485-
buffered_partitions.len(),
486-
1,
487-
"Expected 1 partition but got {}",
488-
buffered_partitions.len()
489-
);
490-
491-
// TODO the single partition case could be optimized to avoid appending all
492-
// rows from the batch into builders and then recreating the batch
493-
// https://github.com/apache/datafusion-comet/issues/1453
494-
let indices = (0..input.num_rows()).collect::<Vec<usize>>();
495-
496-
self.append_rows_to_partition(input.columns(), &indices, 0)
497-
.await?;
498-
}
499489
other => {
500490
// this should be unreachable as long as the validation logic
501491
// in the constructor is kept up-to-date
@@ -511,8 +501,8 @@ impl ShuffleRepartitioner {
511501
/// Writes buffered shuffled record batches into Arrow IPC bytes.
512502
async fn shuffle_write(&mut self) -> Result<SendableRecordBatchStream> {
513503
let mut elapsed_compute = self.metrics.baseline.elapsed_compute().timer();
514-
let num_output_partitions = self.num_output_partitions;
515504
let buffered_partitions = &mut self.buffered_partitions;
505+
let num_output_partitions = buffered_partitions.len();
516506
let mut output_batches: Vec<Vec<u8>> = vec![vec![]; num_output_partitions];
517507
let mut offsets = vec![0; num_output_partitions + 1];
518508
for i in 0..num_output_partitions {

0 commit comments

Comments
 (0)