@@ -23,6 +23,7 @@ use datafusion::prelude::SessionContext;
2323use datafusion_proto:: bytes:: physical_plan_from_bytes_with_extension_codec;
2424use pyo3:: prelude:: * ;
2525use pyo3:: types:: PyBytes ;
26+ use std:: collections:: HashSet ;
2627use std:: sync:: Arc ;
2728
2829#[ pyclass( name = "QueryStage" , module = "datafusion_ray" , subclass) ]
@@ -99,14 +100,23 @@ impl QueryStage {
99100 /// Get the input partition count. This is the same as the number of concurrent tasks
100101 /// when we schedule this query stage for execution
101102 pub fn get_input_partition_count ( & self ) -> usize {
102- if self . plan . children ( ) . is_empty ( ) {
103- // leaf node (file scan)
104- self . plan . output_partitioning ( ) . partition_count ( )
105- } else {
106- self . plan . children ( ) [ 0 ]
107- . output_partitioning ( )
108- . partition_count ( )
103+ let mut output_partition_counts = HashSet :: new ( ) ;
104+
105+ for child in self . plan . children ( ) {
106+ output_partition_counts. insert ( child. output_partitioning ( ) . partition_count ( ) ) ;
107+ if output_partition_counts. len ( ) > 1 {
108+ panic ! (
109+ "Children plan of {:#?} have a distinct outout partitioning partition count" ,
110+ self . plan
111+ ) ;
112+ }
109113 }
114+ // If this stage is a leaf node (file scan), it won't have children
115+ // so we return the partition count of the plan itself
116+ output_partition_counts
117+ . into_iter ( )
118+ . next ( )
119+ . unwrap_or ( self . plan . output_partitioning ( ) . partition_count ( ) )
110120 }
111121
112122 pub fn get_output_partition_count ( & self ) -> usize {
0 commit comments