@@ -37,7 +37,6 @@ use datafusion_proto::bytes::{
3737} ;
3838use datafusion_proto:: physical_plan:: { AsExecutionPlan , DefaultPhysicalExtensionCodec } ;
3939use datafusion_proto:: protobuf;
40- use datafusion_python:: physical_plan:: PyExecutionPlan ;
4140use futures:: StreamExt ;
4241use prost:: { DecodeError , Message } ;
4342use pyo3:: exceptions:: PyRuntimeError ;
@@ -56,6 +55,26 @@ pub struct PyContext {
5655 use_ray_shuffle : bool ,
5756}
5857
58+ pub ( crate ) fn execution_plan_from_pyany (
59+ py_plan : & Bound < PyAny > ,
60+ ) -> PyResult < Arc < dyn ExecutionPlan > > {
61+ let py_proto = py_plan. call_method0 ( "to_proto" ) ?;
62+ let plan_bytes: & [ u8 ] = py_proto. extract ( ) ?;
63+ let plan_node = protobuf:: PhysicalPlanNode :: try_decode ( plan_bytes) . map_err ( |e| {
64+ PyRuntimeError :: new_err ( format ! (
65+ "Unable to decode physical plan protobuf message: {}" ,
66+ e
67+ ) )
68+ } ) ?;
69+
70+ let codec = DefaultPhysicalExtensionCodec { } ;
71+ let runtime = RuntimeEnv :: default ( ) ;
72+ let registry = SessionContext :: new ( ) ;
73+ plan_node
74+ . try_into_physical_plan ( & registry, & runtime, & codec)
75+ . map_err ( |e| e. into ( ) )
76+ }
77+
5978#[ pymethods]
6079impl PyContext {
6180 #[ new]
@@ -117,20 +136,9 @@ impl PyContext {
117136 // let df = wait_for_future(py, self.ctx.sql(sql))?;
118137 let py_df = self . run_sql ( sql, py) ?;
119138 let py_plan = py_df. call_method0 ( py, "execution_plan" ) ?;
120- let py_proto = py_plan. call_method0 ( py, "to_proto" ) ?;
121- let plan_bytes: & [ u8 ] = py_proto. extract ( py) ?;
122- let plan_node = protobuf:: PhysicalPlanNode :: decode ( plan_bytes) . map_err ( |e| {
123- PyRuntimeError :: new_err ( format ! (
124- "Unable to decode physical plan protobuf message: {}" ,
125- e
126- ) )
127- } ) ?;
128-
129- let codec = DefaultPhysicalExtensionCodec { } ;
130- let runtime = RuntimeEnv :: default ( ) ;
131- let registry = SessionContext :: new ( ) ;
132- let plan = plan_node. try_into_physical_plan ( & registry, & runtime, & codec) ?;
139+ let py_plan = py_plan. bind ( py) ;
133140
141+ let plan = execution_plan_from_pyany ( py_plan) ?;
134142 let graph = make_execution_graph ( plan. clone ( ) , self . use_ray_shuffle ) ?;
135143
136144 // debug logging
@@ -150,7 +158,7 @@ impl PyContext {
150158 /// Execute a partition of a query plan. This will typically be executing a shuffle write and write the results to disk
151159 pub fn execute_partition (
152160 & self ,
153- plan : PyExecutionPlan ,
161+ plan : & Bound < ' _ , PyAny > ,
154162 part : usize ,
155163 inputs : PyObject ,
156164 py : Python ,
@@ -161,7 +169,7 @@ impl PyContext {
161169
162170#[ pyfunction]
163171pub fn execute_partition (
164- plan : PyExecutionPlan ,
172+ plan : & Bound < ' _ , PyAny > ,
165173 part : usize ,
166174 inputs : PyObject ,
167175 py : Python ,
@@ -174,25 +182,25 @@ pub fn execute_partition(
174182}
175183
176184// TODO(@lsf) change this to use pickle
177- #[ pyfunction]
178- pub fn serialize_execution_plan ( plan : PyExecutionPlan ) -> PyResult < Vec < u8 > > {
179- let codec = ShuffleCodec { } ;
180- Ok ( physical_plan_to_bytes_with_extension_codec ( plan. plan , & codec) ?. to_vec ( ) )
181- }
185+ // #[pyfunction]
186+ // pub fn serialize_execution_plan(plan: Py<PyAny> ) -> PyResult<Vec<u8>> {
187+ // let codec = ShuffleCodec {};
188+ // Ok(physical_plan_to_bytes_with_extension_codec(plan.plan, &codec)?.to_vec())
189+ // }
182190
183- #[ pyfunction]
184- pub fn deserialize_execution_plan ( bytes : Vec < u8 > ) -> PyResult < PyExecutionPlan > {
185- let ctx = SessionContext :: new ( ) ;
186- let codec = ShuffleCodec { } ;
187- Ok ( PyExecutionPlan :: new (
188- physical_plan_from_bytes_with_extension_codec ( & bytes, & ctx, & codec) ?,
189- ) )
190- }
191+ // #[pyfunction]
192+ // pub fn deserialize_execution_plan(bytes: Vec<u8>) -> PyResult<PyExecutionPlan> {
193+ // let ctx = SessionContext::new();
194+ // let codec = ShuffleCodec {};
195+ // Ok(PyExecutionPlan::new(
196+ // physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?,
197+ // ))
198+ // }
191199
192200/// Iterate down an ExecutionPlan and set the input objects for RayShuffleReaderExec.
193201fn _set_inputs_for_ray_shuffle_reader (
194202 plan : Arc < dyn ExecutionPlan > ,
195- input_partitions : & PyList ,
203+ input_partitions : & Bound < ' _ , PyList > ,
196204) -> Result < ( ) > {
197205 if let Some ( reader_exec) = plan. as_any ( ) . downcast_ref :: < RayShuffleReaderExec > ( ) {
198206 let exec_stage_id = reader_exec. stage_id ;
@@ -218,8 +226,8 @@ fn _set_inputs_for_ray_shuffle_reader(
218226 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{}" , e) ) ) ?
219227 . extract :: < usize > ( )
220228 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{}" , e) ) ) ?;
221- let batch = RecordBatch :: from_pyarrow (
222- pytuple
229+ let batch = RecordBatch :: from_pyarrow_bound (
230+ & pytuple
223231 . get_item ( 2 )
224232 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{}" , e) ) ) ?,
225233 )
@@ -238,7 +246,7 @@ fn _set_inputs_for_ray_shuffle_reader(
238246/// write the results to disk, except for the final query stage, which will return the data.
239247/// inputs is a list of tuples of (stage_id, partition_id, bytes) for each input partition.
240248fn _execute_partition (
241- plan : PyExecutionPlan ,
249+ py_plan : & Bound < ' _ , PyAny > ,
242250 part : usize ,
243251 inputs : PyObject ,
244252) -> Result < Vec < RecordBatch > > {
@@ -251,19 +259,21 @@ fn _execute_partition(
251259 HashMap :: new ( ) ,
252260 Arc :: new ( RuntimeEnv :: default ( ) ) ,
253261 ) ) ;
262+
263+ let plan = execution_plan_from_pyany ( py_plan)
264+ . map_err ( |e| DataFusionError :: Execution ( e. to_string ( ) ) ) ?;
254265 Python :: with_gil ( |py| {
255266 let input_partitions = inputs
256- . as_ref ( py)
257- . downcast :: < PyList > ( )
267+ . downcast_bound :: < PyList > ( py)
258268 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{}" , e) ) ) ?;
259- _set_inputs_for_ray_shuffle_reader ( plan. plan . clone ( ) , input_partitions)
269+ _set_inputs_for_ray_shuffle_reader ( plan. clone ( ) , input_partitions)
260270 } ) ?;
261271
262272 // create a Tokio runtime to run the async code
263273 let rt = Runtime :: new ( ) . unwrap ( ) ;
264274
265275 let fut: JoinHandle < Result < Vec < RecordBatch > > > = rt. spawn ( async move {
266- let mut stream = plan. plan . execute ( part, ctx) ?;
276+ let mut stream = plan. execute ( part, ctx) ?;
267277 let mut results = vec ! [ ] ;
268278 while let Some ( result) = stream. next ( ) . await {
269279 results. push ( result?) ;
0 commit comments