Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/codec.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use arrow::datatypes::Schema;
use arrow::datatypes::{Field, Schema};
use datafusion::{
common::{internal_datafusion_err, internal_err, Result},
execution::FunctionRegistry,
Expand Down Expand Up @@ -243,10 +243,10 @@ mod test {
isolator::PartitionIsolatorExec, max_rows::MaxRowsExec, stage_reader::DDStageReaderExec,
};

fn create_test_schema() -> Arc<arrow::datatypes::Schema> {
Arc::new(arrow::datatypes::Schema::new(vec![
arrow::datatypes::Field::new("a", DataType::Int32, false),
arrow::datatypes::Field::new("b", DataType::Int32, false),
fn create_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]))
}

Expand Down
129 changes: 129 additions & 0 deletions src/distribution_strategy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/// A `Grouper` decides how a number of partitions is split into [`PartitionGroup`]s.
pub trait Grouper {
/// Splits `num_partitions` partitions into groups and returns them as a vec.
///
/// Each returned [`PartitionGroup`] describes a half-open range `[start, end)` of partition
/// indices.
fn group(&self, num_partitions: usize) -> Vec<PartitionGroup>;
}

/// A contiguous range of partitions, `[start, end)`.
///
/// Only the two bounds are stored, which is more space efficient than keeping a vector of
/// u64 partition numbers.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so groups can be used as keys in
/// hash-based collections and in APIs that require total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PartitionGroup {
    /// First partition in the range (inclusive).
    start: usize,
    /// End of the range (exclusive).
    end: usize,
}

impl PartitionGroup {
// new creates a new PartitionGroup containing partitions in the range [start..end).
pub fn new(start: usize, end: usize) -> Self {
Self { start, end }
}

// start is the first in the range
pub fn start(&self) -> usize {
self.start
}

// end is the exclusive end partition in the range
pub fn end(&self) -> usize {
self.end
}
}

/// Groups a number of partitions together according to a `partition_group_size`.
///
/// Example: 10 partitions with a group size of 3 yield the groups
/// `[(0..3), (3..6), (6..9), (9..10)]`.
///
/// - A `partition_group_size` of 0 will panic.
/// - Grouping 0 partitions results in an empty vec.
/// - Unbalanced groupings are possible. For example, with a group size of 99 and 100
///   partitions you get `[(0..99), (99..100)]`.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the grouper can be compared
/// with total equality and used in hash-based collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct PartitionGrouper {
    /// Maximum number of partitions placed in a single group.
    partition_group_size: usize,
}

impl PartitionGrouper {
pub fn new(partition_group_size: usize) -> Self {
assert!(
partition_group_size > 0,
"partition groups cannot be size 0"
);
PartitionGrouper {
partition_group_size,
}
}
}

impl Grouper for PartitionGrouper {
    /// Splits `0..num_partitions` into consecutive groups of at most `partition_group_size`
    /// partitions each.
    ///
    /// Groups are returned in ascending order. The last group is shorter when
    /// `num_partitions` is not a multiple of the group size, and the result is empty when
    /// `num_partitions` is 0.
    fn group(&self, num_partitions: usize) -> Vec<PartitionGroup> {
        let group_size = self.partition_group_size;
        (0..num_partitions)
            .step_by(group_size)
            .map(|start| PartitionGroup {
                start,
                // `start + group_size` can exceed `usize::MAX` for very large group sizes.
                // Saturate instead of overflowing; the result is clamped to `num_partitions`
                // anyway.
                end: start.saturating_add(group_size).min(num_partitions),
            })
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand for the expected group covering `[start, end)`.
    fn range(start: usize, end: usize) -> PartitionGroup {
        PartitionGroup { start, end }
    }

    // 10 partitions in groups of 4: two full groups and a shorter trailing one.
    #[test]
    fn test_partition_grouper_basic() {
        let actual = PartitionGrouper::new(4).group(10);
        assert_eq!(actual, vec![range(0, 4), range(4, 8), range(8, 10)]);
    }

    // 5 partitions in groups of 2: the last group holds the single leftover partition.
    #[test]
    fn test_partition_grouper_uneven() {
        let actual = PartitionGrouper::new(2).group(5);
        assert_eq!(actual, vec![range(0, 2), range(2, 4), range(4, 5)]);
    }

    // A group size of 0 is rejected at construction time.
    #[test]
    #[should_panic]
    fn test_invalid_group_size() {
        let _ = PartitionGrouper::new(0);
    }

    // Fewer partitions than the group size: one group covering everything.
    #[test]
    fn test_num_partitions_smaller_than_group_size() {
        let actual = PartitionGrouper::new(2).group(1);
        assert_eq!(actual, vec![range(0, 1)]);
    }

    // Exactly one full group.
    #[test]
    fn test_num_partitions_equal_to_group_size() {
        let actual = PartitionGrouper::new(2).group(2);
        assert_eq!(actual, vec![range(0, 2)]);
    }

    // Nothing to group yields no groups.
    #[test]
    fn test_zero_partitions_to_group() {
        let actual = PartitionGrouper::new(2).group(0);
        assert_eq!(actual, Vec::<PartitionGroup>::new());
    }
}
180 changes: 153 additions & 27 deletions src/isolator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ use crate::{
vocab::{CtxHost, CtxPartitionGroup},
};

/// This is a simple execution plan that isolates a partition from the input
/// plan It will advertise that it has a single partition and when
/// asked to execute, it will execute a particular partition from the child
/// input plan.
/// This executor isolates partitions from the input plan. It will advertise that it has all
/// the partitions and when asked to execute, it will return empty streams for any partition that
/// is not in its partition group.
///
/// This allows us to execute Repartition Exec's on different processes
/// by showing each one only a single child partition
/// This allows us to execute Repartition Exec's on different processes. The idea is that each
/// process reads all of the input partitions but only outputs the partitions in its partition
/// group.
#[derive(Debug)]
pub struct PartitionIsolatorExec {
pub input: Arc<dyn ExecutionPlan>,
Expand All @@ -30,6 +30,12 @@ pub struct PartitionIsolatorExec {
}

impl PartitionIsolatorExec {
// new creates a new PartitionIsolatorExec. It will advertise that it has partition_count
// partitions but return empty streams for any partitions not in its group.
// TODO: Ideally, we only advertise partitions in the partition group. This way, the parent
// only needs to call execute(0), execute(1) etc if there's 2 partitions in the group. Right now,
// we don't know the number of partitions in the group, so we have to advertise all and the
// parent will call execute(0)..execute(partition_count-1).
pub fn new(input: Arc<dyn ExecutionPlan>, partition_count: usize) -> Self {
// We advertise that we only have partition_count partitions
let properties = input
Expand Down Expand Up @@ -73,9 +79,9 @@ impl ExecutionPlan for PartitionIsolatorExec {
}

fn with_new_children(
self: std::sync::Arc<Self>,
children: Vec<std::sync::Arc<dyn ExecutionPlan>>,
) -> Result<std::sync::Arc<dyn ExecutionPlan>> {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
// TODO: generalize this
assert_eq!(children.len(), 1);
Ok(Arc::new(Self::new(
Expand All @@ -87,7 +93,7 @@ impl ExecutionPlan for PartitionIsolatorExec {
fn execute(
&self,
partition: usize,
context: std::sync::Arc<datafusion::execution::TaskContext>,
context: Arc<datafusion::execution::TaskContext>,
) -> Result<SendableRecordBatchStream> {
let config = context.session_config();
let partition_group = &config
Expand Down Expand Up @@ -116,30 +122,150 @@ impl ExecutionPlan for PartitionIsolatorExec {

let partitions_in_input = self.input.output_partitioning().partition_count() as u64;

let output_stream = match partition_group.get(partition) {
Some(actual_partition_number) => {
if partition_group.len() == 0 {
trace!(
"{} returning empty stream due to empty partition group",
ctx_name
);
return Ok(Box::pin(EmptyRecordBatchStream::new(self.input.schema()))
as SendableRecordBatchStream);
}

// TODO(#59): This is inefficient. Once partition groups are well defined ranges, this
// check will be faster.
match partition_group.contains(&(partition as u64)) {
true => {
trace!(
"PartitionIsolatorExec::execute: {}, partition_group={:?}, requested \
partition={} actual={},\ninput partitions={}",
partition={} \ninput partitions={}",
ctx_name,
partition_group,
partition,
*actual_partition_number,
partitions_in_input
);
if *actual_partition_number >= partitions_in_input {
trace!("{} returning empty stream", ctx_name);
Ok(Box::pin(EmptyRecordBatchStream::new(self.input.schema()))
as SendableRecordBatchStream)
} else {
trace!("{} returning actual stream", ctx_name);
self.input
.execute(*actual_partition_number as usize, context)
}
trace!("{} returning actual stream", ctx_name);
self.input.execute(partition, context)
}
false => {
trace!("{} returning empty stream", ctx_name);
Ok(Box::pin(EmptyRecordBatchStream::new(self.input.schema()))
as SendableRecordBatchStream)
}
None => Ok(Box::pin(EmptyRecordBatchStream::new(self.input.schema()))
as SendableRecordBatchStream),
};
output_stream
}
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{record_batch_exec::RecordBatchExec, vocab::CtxPartitionGroup};
    use arrow::array::{Int32Array, RecordBatch};
    use datafusion::{
        arrow::datatypes::{DataType, Field, Schema},
        prelude::SessionContext,
    };
    use futures::StreamExt;
    use std::sync::Arc;

    // Builds the input plan used by the tests: a RecordBatchExec over one batch with a single
    // non-nullable Int32 column "col1" holding the values 1, 2, 3.
    fn create_test_record_batch_exec() -> Arc<dyn ExecutionPlan> {
        let fields = vec![Field::new("col1", DataType::Int32, false)];
        let schema = Arc::new(Schema::new(fields));
        let values = Int32Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(schema, vec![Arc::new(values)]).unwrap();
        Arc::new(RecordBatchExec::new(batch))
    }

    // Checks the success path plus the two error paths of execute(): a partition index the
    // test expects to be rejected, and a task context without a partition group extension.
    #[test]
    fn test_partition_isolator_exec() {
        let isolator = PartitionIsolatorExec::new(create_test_record_batch_exec(), 3);

        // Register a partition group covering every advertised partition.
        let ctx = SessionContext::new();
        {
            let state = ctx.state_ref();
            let mut guard = state.write();
            guard
                .config_mut()
                .set_extension(Arc::new(CtxPartitionGroup(vec![0u64, 1u64, 2u64])));
        }
        let task_context = ctx.task_ctx();

        // Success case: execute valid partition
        assert!(isolator.execute(0, task_context.clone()).is_ok());

        // Error case: try to execute partition beyond partition_count
        let out_of_range = isolator
            .execute(4, task_context)
            .err()
            .expect("executing partition 4 should fail");
        assert!(out_of_range
            .to_string()
            .contains("Invalid partition 4 for PartitionIsolatorExec"));

        // Error case: task context without the partition group extension. Every partition
        // should report the missing extension.
        let empty_ctx = SessionContext::new();
        let empty_task_context = empty_ctx.task_ctx();
        for partition in 0..2 {
            let missing_group = isolator
                .execute(partition, empty_task_context.clone())
                .err()
                .expect("executing without a partition group should fail");
            assert!(missing_group
                .to_string()
                .contains("PartitionGroup not set in session config"));
        }
    }

    // With a partition group that is a strict subset of the advertised partitions, partitions
    // inside the group must yield data and partitions outside it must yield nothing.
    #[tokio::test]
    async fn test_partition_isolator_exec_with_group() {
        let partition_count = 6;
        let isolator =
            PartitionIsolatorExec::new(create_test_record_batch_exec(), partition_count);

        // Partition group is a subset of the partitions.
        let ctx = SessionContext::new();
        {
            let state = ctx.state_ref();
            let mut guard = state.write();
            guard
                .config_mut()
                .set_extension(Arc::new(CtxPartitionGroup(vec![1u64, 2u64, 3u64, 4u64])));
        }
        let task_context = ctx.task_ctx();

        for partition in 0..partition_count {
            let mut stream = isolator
                .execute(partition, task_context.clone())
                .expect("execute should succeed for every advertised partition");
            let first_batch = stream.next().await;

            // Partitions 1 through 4 are in the group; 0 and 5 are outside it.
            let in_group = (1..=4).contains(&partition);
            if in_group {
                assert!(
                    first_batch.is_some(),
                    "Expected Stream to produce non-empty batches"
                );
            } else {
                assert!(
                    first_batch.is_none(),
                    "Expected EmptyRecordBatchStream to produce no batches"
                );
            }
        }
    }
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub use proto::generated::protobuf;
pub mod analyze;
pub mod codec;
pub mod customizer;
pub mod distribution_strategy;
pub mod explain;
pub mod flight;
pub mod friendly;
Expand Down
Loading
Loading