1616// under the License.
1717
1818use crate :: context:: serialize_execution_plan;
19- use crate :: shuffle:: { ShuffleCodec , ShuffleReaderExec } ;
19+ use crate :: shuffle:: { ShuffleCodec , ShuffleReaderExec , ShuffleWriterExec } ;
2020use datafusion:: error:: Result ;
2121use datafusion:: physical_plan:: { ExecutionPlan , ExecutionPlanProperties , Partitioning } ;
2222use datafusion:: prelude:: SessionContext ;
@@ -60,8 +60,8 @@ impl PyQueryStage {
6060 self . stage . get_child_stage_ids ( )
6161 }
6262
63- pub fn get_input_partition_count ( & self ) -> usize {
64- self . stage . get_input_partition_count ( )
63+ pub fn get_execution_partition_count ( & self ) -> usize {
64+ self . stage . get_execution_partition_count ( )
6565 }
6666
6767 pub fn get_output_partition_count ( & self ) -> usize {
@@ -75,16 +75,6 @@ pub struct QueryStage {
7575 pub plan : Arc < dyn ExecutionPlan > ,
7676}
7777
78- fn _get_output_partition_count ( plan : & dyn ExecutionPlan ) -> usize {
79- // UnknownPartitioning and HashPartitioning with empty expressions will
80- // both return 1 partition.
81- match plan. properties ( ) . output_partitioning ( ) {
82- Partitioning :: UnknownPartitioning ( _) => 1 ,
83- Partitioning :: Hash ( expr, _) if expr. is_empty ( ) => 1 ,
84- p => p. partition_count ( ) ,
85- }
86- }
87-
8878impl QueryStage {
8979 pub fn new ( id : usize , plan : Arc < dyn ExecutionPlan > ) -> Self {
9080 Self { id, plan }
@@ -96,21 +86,27 @@ impl QueryStage {
9686 ids
9787 }
9888
99- /// Get the input partition count. This is the same as the number of concurrent tasks
100- /// when we schedule this query stage for execution
101- pub fn get_input_partition_count ( & self ) -> usize {
102- if self . plan . children ( ) . is_empty ( ) {
103- // leaf node (file scan)
104- self . plan . output_partitioning ( ) . partition_count ( )
89+ /// Get the number of partitions that can be executed in parallel
90+ pub fn get_execution_partition_count ( & self ) -> usize {
91+ if let Some ( shuffle) = self . plan . as_any ( ) . downcast_ref :: < ShuffleWriterExec > ( ) {
92+ // use the partitioning of the input to the shuffle write because we are
93+ // really executing that and then using the shuffle writer to repartition
94+ // the output
95+ shuffle. input_plan . output_partitioning ( ) . partition_count ( )
10596 } else {
106- self . plan . children ( ) [ 0 ]
107- . output_partitioning ( )
108- . partition_count ( )
97+ // for any other plan, use its output partitioning
98+ self . plan . output_partitioning ( ) . partition_count ( )
10999 }
110100 }
111101
112102 pub fn get_output_partition_count ( & self ) -> usize {
113- _get_output_partition_count ( self . plan . as_ref ( ) )
103+ // UnknownPartitioning and HashPartitioning with empty expressions will
104+ // both return 1 partition.
105+ match self . plan . properties ( ) . output_partitioning ( ) {
106+ Partitioning :: UnknownPartitioning ( _) => 1 ,
107+ Partitioning :: Hash ( expr, _) if expr. is_empty ( ) => 1 ,
108+ p => p. partition_count ( ) ,
109+ }
114110 }
115111}
116112
0 commit comments