@@ -15,13 +15,13 @@ use crate::{
1515 vocab:: { CtxHost , CtxPartitionGroup } ,
1616} ;
1717
18- /// This is a simple execution plan that isolates a partition from the input
19- /// plan It will advertise that it has a single partition and when
20- /// asked to execute, it will execute a particular partition from the child
21- /// input plan.
18+ /// This executor isolates partitions from the input plan. It will advertise that it has all
19+ /// the partitions and when asked to execute, it will return empty streams for any partition that
20+ /// is not in its partition group.
2221///
23- /// This allows us to execute Repartition Exec's on different processes
24- /// by showing each one only a single child partition
22+ /// This allows us to execute Repartition Exec's on different processes. The idea is that each
23+ /// process reads all of the input partitions but only outputs the partitions in its partition
24+ /// group.
2525#[ derive( Debug ) ]
2626pub struct PartitionIsolatorExec {
2727 pub input : Arc < dyn ExecutionPlan > ,
@@ -30,6 +30,12 @@ pub struct PartitionIsolatorExec {
3030}
3131
3232impl PartitionIsolatorExec {
33+ // new creates a new PartitionIsolatorExec. It will advertise that it has partition_count
34+ // partitions but return empty streams for any partitions not in its group.
35+ // TODO: Ideally, we only advertise partitions in the partition group. This way, the parent
36+ // only needs to call execute(0), execute(1) etc if there are 2 partitions in the group. Right now,
37+ // we don't know the number of partitions in the group, so we have to advertise all and the
38+ // parent will call execute(0)..execute(partition_count-1).
3339 pub fn new ( input : Arc < dyn ExecutionPlan > , partition_count : usize ) -> Self {
3440 // We advertise that we only have partition_count partitions
3541 let properties = input
@@ -73,9 +79,9 @@ impl ExecutionPlan for PartitionIsolatorExec {
7379 }
7480
7581 fn with_new_children (
76- self : std :: sync :: Arc < Self > ,
77- children : Vec < std :: sync :: Arc < dyn ExecutionPlan > > ,
78- ) -> Result < std :: sync :: Arc < dyn ExecutionPlan > > {
82+ self : Arc < Self > ,
83+ children : Vec < Arc < dyn ExecutionPlan > > ,
84+ ) -> Result < Arc < dyn ExecutionPlan > > {
7985 // TODO: generalize this
8086 assert_eq ! ( children. len( ) , 1 ) ;
8187 Ok ( Arc :: new ( Self :: new (
@@ -87,7 +93,7 @@ impl ExecutionPlan for PartitionIsolatorExec {
8793 fn execute (
8894 & self ,
8995 partition : usize ,
90- context : std :: sync :: Arc < datafusion:: execution:: TaskContext > ,
96+ context : Arc < datafusion:: execution:: TaskContext > ,
9197 ) -> Result < SendableRecordBatchStream > {
9298 let config = context. session_config ( ) ;
9399 let partition_group = & config
@@ -116,30 +122,150 @@ impl ExecutionPlan for PartitionIsolatorExec {
116122
117123 let partitions_in_input = self . input . output_partitioning ( ) . partition_count ( ) as u64 ;
118124
119- let output_stream = match partition_group. get ( partition) {
120- Some ( actual_partition_number) => {
125+ if partition_group. len ( ) == 0 {
126+ trace ! (
127+ "{} returning empty stream due to empty partition group" ,
128+ ctx_name
129+ ) ;
130+ return Ok ( Box :: pin ( EmptyRecordBatchStream :: new ( self . input . schema ( ) ) )
131+ as SendableRecordBatchStream ) ;
132+ }
133+
134+ // TODO(#59): This is inefficient. Once partition groups are well defined ranges, this
135+ // check will be faster.
136+ match partition_group. contains ( & ( partition as u64 ) ) {
137+ true => {
121138 trace ! (
122139 "PartitionIsolatorExec::execute: {}, partition_group={:?}, requested \
123- partition={} actual={}, \n input partitions={}",
140+ partition={} \n input partitions={}",
124141 ctx_name,
125142 partition_group,
126143 partition,
127- * actual_partition_number,
128144 partitions_in_input
129145 ) ;
130- if * actual_partition_number >= partitions_in_input {
131- trace ! ( "{} returning empty stream" , ctx_name) ;
132- Ok ( Box :: pin ( EmptyRecordBatchStream :: new ( self . input . schema ( ) ) )
133- as SendableRecordBatchStream )
134- } else {
135- trace ! ( "{} returning actual stream" , ctx_name) ;
136- self . input
137- . execute ( * actual_partition_number as usize , context)
138- }
146+ trace ! ( "{} returning actual stream" , ctx_name) ;
147+ self . input . execute ( partition, context)
148+ }
149+ false => {
150+ trace ! ( "{} returning empty stream" , ctx_name) ;
151+ Ok ( Box :: pin ( EmptyRecordBatchStream :: new ( self . input . schema ( ) ) )
152+ as SendableRecordBatchStream )
139153 }
140- None => Ok ( Box :: pin ( EmptyRecordBatchStream :: new ( self . input . schema ( ) ) )
141- as SendableRecordBatchStream ) ,
142- } ;
143- output_stream
154+ }
155+ }
156+ }
157+
158+ #[ cfg( test) ]
159+ mod tests {
160+ use super :: * ;
161+ use crate :: { record_batch_exec:: RecordBatchExec , vocab:: CtxPartitionGroup } ;
162+ use arrow:: array:: { Int32Array , RecordBatch } ;
163+ use datafusion:: {
164+ arrow:: datatypes:: { DataType , Field , Schema } ,
165+ prelude:: SessionContext ,
166+ } ;
167+ use futures:: StreamExt ;
168+ use std:: sync:: Arc ;
169+
170+ fn create_test_record_batch_exec ( ) -> Arc < dyn ExecutionPlan > {
171+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
172+ "col1" ,
173+ DataType :: Int32 ,
174+ false ,
175+ ) ] ) ) ;
176+ let batch = RecordBatch :: try_new (
177+ schema. clone ( ) ,
178+ vec ! [ Arc :: new( Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ] ,
179+ )
180+ . unwrap ( ) ;
181+ Arc :: new ( RecordBatchExec :: new ( batch) )
182+ }
183+
184+ #[ test]
185+ fn test_partition_isolator_exec ( ) {
186+ let input = create_test_record_batch_exec ( ) ;
187+ let partition_count = 3 ;
188+ let isolator = PartitionIsolatorExec :: new ( input, partition_count) ;
189+
190+ // Test success case: valid partition with partition group
191+ let ctx = SessionContext :: new ( ) ;
192+ let partition_group = vec ! [ 0u64 , 1u64 , 2u64 ] ;
193+ {
194+ let state = ctx. state_ref ( ) ;
195+ let mut guard = state. write ( ) ;
196+ let config = guard. config_mut ( ) ;
197+ config. set_extension ( Arc :: new ( CtxPartitionGroup ( partition_group) ) ) ;
198+ }
199+
200+ let task_context = ctx. task_ctx ( ) ;
201+
202+ // Success case: execute valid partition
203+ let result = isolator. execute ( 0 , task_context. clone ( ) ) ;
204+ assert ! ( result. is_ok( ) ) ;
205+
206+ // Error case: try to execute partition beyond partition_count
207+ let result = isolator. execute ( 4 , task_context. clone ( ) ) ;
208+ assert ! ( result. is_err( ) ) ;
209+ assert ! ( result
210+ . err( )
211+ . unwrap( )
212+ . to_string( )
213+ . contains( "Invalid partition 4 for PartitionIsolatorExec" ) ) ;
214+
215+ // Error case: test empty task context (missing group extension)
216+ let empty_ctx = SessionContext :: new ( ) ;
217+ let empty_task_context = empty_ctx. task_ctx ( ) ;
218+
219+ let result = isolator. execute ( 0 , empty_task_context. clone ( ) ) ;
220+ assert ! ( result. is_err( ) ) ;
221+ assert ! ( result
222+ . err( )
223+ . unwrap( )
224+ . to_string( )
225+ . contains( "PartitionGroup not set in session config" ) ) ;
226+
227+ let result = isolator. execute ( 1 , empty_task_context) ;
228+ assert ! ( result. is_err( ) ) ;
229+ assert ! ( result
230+ . err( )
231+ . unwrap( )
232+ . to_string( )
233+ . contains( "PartitionGroup not set in session config" ) ) ;
234+ }
235+
236+ #[ tokio:: test]
237+ async fn test_partition_isolator_exec_with_group ( ) {
238+ let input = create_test_record_batch_exec ( ) ;
239+ let partition_count = 6 ;
240+ let isolator = PartitionIsolatorExec :: new ( input, partition_count) ;
241+
242+ // Partition group is a subset of the partitions.
243+ let ctx = SessionContext :: new ( ) ;
244+ let partition_group = vec ! [ 1u64 , 2u64 , 3u64 , 4u64 ] ;
245+ {
246+ let state = ctx. state_ref ( ) ;
247+ let mut guard = state. write ( ) ;
248+ let config = guard. config_mut ( ) ;
249+ config. set_extension ( Arc :: new ( CtxPartitionGroup ( partition_group) ) ) ;
250+ }
251+
252+ let task_context = ctx. task_ctx ( ) ;
253+ for i in 0 ..6 {
254+ let result = isolator. execute ( i, task_context. clone ( ) ) ;
255+ assert ! ( result. is_ok( ) ) ;
256+ let mut stream = result. unwrap ( ) ;
257+ let next_batch = stream. next ( ) . await ;
258+ if i == 0 || i == 5 {
259+ assert ! (
260+ next_batch. is_none( ) ,
261+ "Expected EmptyRecordBatchStream to produce no batches"
262+ ) ;
263+ } else {
264+ assert ! (
265+ next_batch. is_some( ) ,
266+ "Expected Stream to produce non-empty batches"
267+ ) ;
268+ }
269+ }
144270 }
145271}
0 commit comments