1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use crate :: ext:: Extensions ;
1819use crate :: planner:: { make_execution_graph, PyExecutionGraph } ;
19- use crate :: shuffle:: ShuffleCodec ;
2020use datafusion:: arrow:: pyarrow:: ToPyArrow ;
2121use datafusion:: arrow:: record_batch:: RecordBatch ;
2222use datafusion:: error:: { DataFusionError , Result } ;
@@ -26,6 +26,7 @@ use datafusion::physical_plan::{displayable, ExecutionPlan};
2626use datafusion:: prelude:: * ;
2727use datafusion_proto:: physical_plan:: AsExecutionPlan ;
2828use datafusion_proto:: protobuf;
29+ use datafusion_python:: physical_plan:: PyExecutionPlan ;
2930use futures:: StreamExt ;
3031use prost:: Message ;
3132use pyo3:: exceptions:: PyRuntimeError ;
@@ -45,22 +46,30 @@ pub struct PyContext {
4546
4647pub ( crate ) fn execution_plan_from_pyany (
4748 py_plan : & Bound < PyAny > ,
49+ py : Python ,
4850) -> PyResult < Arc < dyn ExecutionPlan > > {
49- let py_proto = py_plan. call_method0 ( "to_proto" ) ?;
50- let plan_bytes: & [ u8 ] = py_proto. extract ( ) ?;
51- let plan_node = protobuf:: PhysicalPlanNode :: try_decode ( plan_bytes) . map_err ( |e| {
52- PyRuntimeError :: new_err ( format ! (
53- "Unable to decode physical plan protobuf message: {}" ,
54- e
55- ) )
56- } ) ?;
57-
58- let codec = ShuffleCodec { } ;
59- let runtime = RuntimeEnv :: default ( ) ;
60- let registry = SessionContext :: new ( ) ;
61- plan_node
62- . try_into_physical_plan ( & registry, & runtime, & codec)
63- . map_err ( |e| e. into ( ) )
51+ if let Ok ( py_plan) = py_plan. to_object ( py) . downcast_bound :: < PyExecutionPlan > ( py) {
52+ // For session contexts created with datafusion_ray.extended_session_context(), the inner
53+ // execution plan can be used as such (and the enabled extensions are all available).
54+ Ok ( py_plan. borrow ( ) . plan . clone ( ) )
55+ } else {
56+ // The session context originates from outside our library, so we'll grab the protobuf plan
57+ // by calling the python method with no extension codecs.
58+ let py_proto = py_plan. call_method0 ( "to_proto" ) ?;
59+ let plan_bytes: & [ u8 ] = py_proto. extract ( ) ?;
60+ let plan_node = protobuf:: PhysicalPlanNode :: try_decode ( plan_bytes) . map_err ( |e| {
61+ PyRuntimeError :: new_err ( format ! (
62+ "Unable to decode physical plan protobuf message: {}" ,
63+ e
64+ ) )
65+ } ) ?;
66+
67+ let runtime = RuntimeEnv :: default ( ) ;
68+ let registry = SessionContext :: new ( ) ;
69+ plan_node
70+ . try_into_physical_plan ( & registry, & runtime, Extensions :: codec ( ) )
71+ . map_err ( |e| e. into ( ) )
72+ }
6473}
6574
6675#[ pymethods]
@@ -87,14 +96,14 @@ impl PyContext {
8796 }
8897
8998 /// Plan a distributed SELECT query for executing against the Ray workers
90- pub fn plan ( & self , plan : & Bound < PyAny > ) -> PyResult < PyExecutionGraph > {
99+ pub fn plan ( & self , plan : & Bound < PyAny > , py : Python ) -> PyResult < PyExecutionGraph > {
91100 // println!("Planning {}", sql);
92101 // let df = wait_for_future(py, self.ctx.sql(sql))?;
93102 // let py_df = self.run_sql(sql, py)?;
94103 // let py_plan = py_df.call_method0(py, "execution_plan")?;
95104 // let py_plan = py_plan.bind(py);
96105
97- let plan = execution_plan_from_pyany ( plan) ?;
106+ let plan = execution_plan_from_pyany ( plan, py ) ?;
98107 let graph = make_execution_graph ( plan. clone ( ) ) ?;
99108
100109 // debug logging
@@ -140,9 +149,10 @@ pub fn serialize_execution_plan(
140149 plan : Arc < dyn ExecutionPlan > ,
141150 py : Python ,
142151) -> PyResult < Bound < ' _ , PyBytes > > {
143- let codec = ShuffleCodec { } ;
144- let proto =
145- datafusion_proto:: protobuf:: PhysicalPlanNode :: try_from_physical_plan ( plan. clone ( ) , & codec) ?;
152+ let proto = datafusion_proto:: protobuf:: PhysicalPlanNode :: try_from_physical_plan (
153+ plan. clone ( ) ,
154+ Extensions :: codec ( ) ,
155+ ) ?;
146156
147157 let bytes = proto. encode_to_vec ( ) ;
148158 Ok ( PyBytes :: new_bound ( py, & bytes) )
@@ -159,9 +169,8 @@ pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dy
159169 } ) ?;
160170
161171 let ctx = SessionContext :: new ( ) ;
162- let codec = ShuffleCodec { } ;
163172 let plan = proto_plan
164- . try_into_physical_plan ( & ctx, & ctx. runtime_env ( ) , & codec)
173+ . try_into_physical_plan ( & ctx, & ctx. runtime_env ( ) , Extensions :: codec ( ) )
165174 . map_err ( DataFusionError :: from) ?;
166175
167176 Ok ( plan)
0 commit comments