@@ -36,7 +36,6 @@ use datafusion_proto::bytes::{
 };
 use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
 use datafusion_proto::protobuf;
-use datafusion_python::physical_plan::PyExecutionPlan;
 use futures::StreamExt;
 use prost::{DecodeError, Message};
 use pyo3::exceptions::PyRuntimeError;
@@ -54,6 +53,26 @@ pub struct PyContext {
     pub(crate) py_ctx: PyObject,
 }
 
+pub(crate) fn execution_plan_from_pyany(
+    py_plan: &Bound<PyAny>,
+) -> PyResult<Arc<dyn ExecutionPlan>> {
+    let py_proto = py_plan.call_method0("to_proto")?;
+    let plan_bytes: &[u8] = py_proto.extract()?;
+    let plan_node = protobuf::PhysicalPlanNode::try_decode(plan_bytes).map_err(|e| {
+        PyRuntimeError::new_err(format!(
+            "Unable to decode physical plan protobuf message: {}",
+            e
+        ))
+    })?;
+
+    let codec = DefaultPhysicalExtensionCodec {};
+    let runtime = RuntimeEnv::default();
+    let registry = SessionContext::new();
+    plan_node
+        .try_into_physical_plan(&registry, &runtime, &codec)
+        .map_err(|e| e.into())
+}
+
 #[pymethods]
 impl PyContext {
     #[new]
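The helper added above centralizes the Python-to-Rust plan handoff: any Python plan object exposing a `to_proto()` method that returns protobuf bytes can be decoded into a native `Arc<dyn ExecutionPlan>`. A minimal usage sketch, assuming `py_df` is a datafusion-python DataFrame (the same pattern the `sql` method adopts in the next hunk):

```rust
// Sketch only: `py_df` is assumed to be a datafusion-python DataFrame whose
// execution_plan() returns an object with a to_proto() -> bytes method.
Python::with_gil(|py| -> PyResult<()> {
    let py_plan = py_df.call_method0(py, "execution_plan")?;
    // Bind the GIL-independent PyObject to this GIL token, then decode it
    // into a native DataFusion plan via the new helper.
    let plan = execution_plan_from_pyany(py_plan.bind(py))?;
    println!("decoded plan schema: {:?}", plan.schema());
    Ok(())
})?;
```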
@@ -114,20 +133,9 @@ impl PyContext {
         // let df = wait_for_future(py, self.ctx.sql(sql))?;
         let py_df = self.run_sql(sql, py)?;
         let py_plan = py_df.call_method0(py, "execution_plan")?;
-        let py_proto = py_plan.call_method0(py, "to_proto")?;
-        let plan_bytes: &[u8] = py_proto.extract(py)?;
-        let plan_node = protobuf::PhysicalPlanNode::decode(plan_bytes).map_err(|e| {
-            PyRuntimeError::new_err(format!(
-                "Unable to decode physical plan protobuf message: {}",
-                e
-            ))
-        })?;
-
-        let codec = DefaultPhysicalExtensionCodec {};
-        let runtime = RuntimeEnv::default();
-        let registry = SessionContext::new();
-        let plan = plan_node.try_into_physical_plan(&registry, &runtime, &codec)?;
+        let py_plan = py_plan.bind(py);
 
+        let plan = execution_plan_from_pyany(py_plan)?;
         let graph = make_execution_graph(plan.clone())?;
 
         // debug logging
@@ -147,7 +155,7 @@ impl PyContext {
     /// Execute a partition of a query plan. This will typically perform a shuffle write and write the results to disk.
     pub fn execute_partition(
         &self,
-        plan: PyExecutionPlan,
+        plan: &Bound<'_, PyAny>,
         part: usize,
         inputs: PyObject,
         py: Python,
@@ -158,7 +166,7 @@
 
 #[pyfunction]
 pub fn execute_partition(
-    plan: PyExecutionPlan,
+    plan: &Bound<'_, PyAny>,
     part: usize,
     inputs: PyObject,
     py: Python,
@@ -171,20 +179,20 @@ pub fn execute_partition(
 }
 
 // TODO(@lsf) change this to use pickle
-#[pyfunction]
-pub fn serialize_execution_plan(plan: PyExecutionPlan) -> PyResult<Vec<u8>> {
-    let codec = ShuffleCodec {};
-    Ok(physical_plan_to_bytes_with_extension_codec(plan.plan, &codec)?.to_vec())
-}
+// #[pyfunction]
+// pub fn serialize_execution_plan(plan: Py<PyAny>) -> PyResult<Vec<u8>> {
+//     let codec = ShuffleCodec {};
+//     Ok(physical_plan_to_bytes_with_extension_codec(plan.plan, &codec)?.to_vec())
+// }
 
-#[pyfunction]
-pub fn deserialize_execution_plan(bytes: Vec<u8>) -> PyResult<PyExecutionPlan> {
-    let ctx = SessionContext::new();
-    let codec = ShuffleCodec {};
-    Ok(PyExecutionPlan::new(
-        physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?,
-    ))
-}
+// #[pyfunction]
+// pub fn deserialize_execution_plan(bytes: Vec<u8>) -> PyResult<PyExecutionPlan> {
+//     let ctx = SessionContext::new();
+//     let codec = ShuffleCodec {};
+//     Ok(PyExecutionPlan::new(
+//         physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?,
+//     ))
+// }
 
 /// Iterate down an ExecutionPlan and set the input objects for RayShuffleReaderExec.
 fn _set_inputs_for_ray_shuffle_reader(
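Although disabled pending the pickle rework flagged in the TODO, the commented-out pair still documents the byte-level round trip that any replacement must preserve. A standalone sketch of that round trip, assuming a `plan: Arc<dyn ExecutionPlan>` already in hand and the crate's own `ShuffleCodec`:

```rust
use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::bytes::{
    physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
};

// Sketch: serialize a plan with the extension codec and decode it back,
// mirroring what serialize/deserialize_execution_plan did before removal.
fn roundtrip(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
    let codec = ShuffleCodec {};
    let bytes = physical_plan_to_bytes_with_extension_codec(plan, &codec)?;
    let ctx = SessionContext::new();
    physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)
}
```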
@@ -235,7 +243,7 @@ fn _set_inputs_for_ray_shuffle_reader(
 /// write the results to disk, except for the final query stage, which will return the data.
 /// inputs is a list of tuples of (stage_id, partition_id, bytes) for each input partition.
 fn _execute_partition(
-    plan: PyExecutionPlan,
+    py_plan: &Bound<'_, PyAny>,
     part: usize,
     inputs: PyObject,
 ) -> Result<Vec<RecordBatch>> {
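As the doc comment notes, `inputs` arrives as a Python list of `(stage_id, partition_id, bytes)` tuples. A test-style sketch of constructing that shape with the pyo3 Bound API this change adopts (payload bytes are placeholders, not real Arrow IPC data):

```rust
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList};

// Sketch: build the `inputs` list shape expected by _execute_partition.
// Each element is (stage_id, partition_id, ipc_bytes); payloads are dummies.
fn make_inputs(py: Python<'_>) -> Bound<'_, PyList> {
    let rows = vec![
        (0usize, 0usize, PyBytes::new_bound(py, b"ipc-0")),
        (0usize, 1usize, PyBytes::new_bound(py, b"ipc-1")),
    ];
    PyList::new_bound(py, rows)
}
```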
@@ -248,19 +256,21 @@ fn _execute_partition(
         HashMap::new(),
         Arc::new(RuntimeEnv::default()),
     ));
+
+    let plan = execution_plan_from_pyany(py_plan)
+        .map_err(|e| DataFusionError::Execution(e.to_string()))?;
     Python::with_gil(|py| {
         let input_partitions = inputs
-            .bind(py)
-            .downcast::<PyList>()
+            .downcast_bound::<PyList>(py)
             .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-        _set_inputs_for_ray_shuffle_reader(plan.plan.clone(), input_partitions)
+        _set_inputs_for_ray_shuffle_reader(plan.clone(), input_partitions)
     })?;
 
     // create a Tokio runtime to run the async code
     let rt = Runtime::new().unwrap();
 
     let fut: JoinHandle<Result<Vec<RecordBatch>>> = rt.spawn(async move {
-        let mut stream = plan.plan.execute(part, ctx)?;
+        let mut stream = plan.execute(part, ctx)?;
         let mut results = vec![];
         while let Some(result) = stream.next().await {
             results.push(result?);