@@ -96,49 +96,10 @@ impl DFRayDataFrame {
9696 ) -> PyResult < Vec < PyDFRayStage > > {
9797 let mut stages = vec ! [ ] ;
9898
99- // TODO: This can be done more efficiently, likely in one pass but I'm
100- // struggling to get the TreeNodeRecursion return values to make it do
101- // what I want. So, two steps for now
102-
103- // Step 2: we walk down this stage and replace stages earlier in the tree with
104- // RayStageReaderExecs as we will need to consume their output instead of
105- // executing that part of the tree ourselves
106- let down = |plan : Arc < dyn ExecutionPlan > | {
107- trace ! (
108- "examining plan down:\n {}" ,
109- display_plan_with_partition_counts( & plan)
110- ) ;
111-
112- if let Some ( stage_exec) = plan. as_any ( ) . downcast_ref :: < DFRayStageExec > ( ) {
113- let input = plan. children ( ) ;
114- assert ! ( input. len( ) == 1 , "RayStageExec must have exactly one child" ) ;
115- let input = input[ 0 ] ;
116-
117- trace ! (
118- "inserting a ray stage reader to consume: {} with partitioning {}" ,
119- displayable( plan. as_ref( ) ) . one_line( ) ,
120- plan. output_partitioning( ) . partition_count( )
121- ) ;
122-
123- let replacement = Arc :: new ( DFRayStageReaderExec :: try_new (
124- plan. output_partitioning ( ) . clone ( ) ,
125- input. schema ( ) ,
126- stage_exec. stage_id ,
127- ) ?) as Arc < dyn ExecutionPlan > ;
128-
129- Ok ( Transformed {
130- data : replacement,
131- transformed : true ,
132- tnr : TreeNodeRecursion :: Jump ,
133- } )
134- } else {
135- Ok ( Transformed :: no ( plan) )
136- }
137- } ;
138-
13999 let mut partition_groups = vec ! [ ] ;
140100 let mut full_partitions = false ;
141- // Step 1: we walk up the tree from the leaves to find the stages
101+ // We walk up the tree from the leaves to find the stages, record ray stages, and replace
102+ // each ray stage with a corresponding ray reader stage.
142103 let up = |plan : Arc < dyn ExecutionPlan > | {
143104 trace ! (
144105 "Examining plan up: {}" ,
@@ -151,19 +112,23 @@ impl DFRayDataFrame {
151112 assert ! ( input. len( ) == 1 , "RayStageExec must have exactly one child" ) ;
152113 let input = input[ 0 ] ;
153114
154- let fixed_plan = input. clone ( ) . transform_down ( down) ?. data ;
115+ let replacement = Arc :: new ( DFRayStageReaderExec :: try_new (
116+ plan. output_partitioning ( ) . clone ( ) ,
117+ input. schema ( ) ,
118+ stage_exec. stage_id ,
119+ ) ?) as Arc < dyn ExecutionPlan > ;
155120
156121 let stage = PyDFRayStage :: new (
157122 stage_exec. stage_id ,
158- fixed_plan ,
123+ input . clone ( ) ,
159124 partition_groups. clone ( ) ,
160125 full_partitions,
161126 ) ;
162127 partition_groups = vec ! [ ] ;
163128 full_partitions = false ;
164129
165130 stages. push ( stage) ;
166- Ok ( Transformed :: no ( plan ) )
131+ Ok ( Transformed :: yes ( replacement ) )
167132 } else if plan. as_any ( ) . downcast_ref :: < RepartitionExec > ( ) . is_some ( ) {
168133 trace ! ( "repartition exec" ) ;
169134 let ( calculated_partition_groups, replacement) = build_replacement (
0 commit comments