1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use crate :: ext:: Extensions ;
1819use crate :: planner:: { make_execution_graph, PyExecutionGraph } ;
19- use crate :: shuffle:: { RayShuffleReaderExec , ShuffleCodec } ;
20+ use crate :: shuffle:: RayShuffleReaderExec ;
2021use datafusion:: arrow:: pyarrow:: FromPyArrow ;
2122use datafusion:: arrow:: pyarrow:: ToPyArrow ;
2223use datafusion:: arrow:: record_batch:: RecordBatch ;
@@ -27,6 +28,7 @@ use datafusion::physical_plan::{displayable, ExecutionPlan};
2728use datafusion:: prelude:: * ;
2829use datafusion_proto:: physical_plan:: AsExecutionPlan ;
2930use datafusion_proto:: protobuf;
31+ use datafusion_python:: physical_plan:: PyExecutionPlan ;
3032use futures:: StreamExt ;
3133use prost:: Message ;
3234use pyo3:: exceptions:: PyRuntimeError ;
@@ -46,22 +48,30 @@ pub struct PyContext {
4648
4749pub ( crate ) fn execution_plan_from_pyany (
4850 py_plan : & Bound < PyAny > ,
51+ py : Python ,
4952) -> PyResult < Arc < dyn ExecutionPlan > > {
50- let py_proto = py_plan. call_method0 ( "to_proto" ) ?;
51- let plan_bytes: & [ u8 ] = py_proto. extract ( ) ?;
52- let plan_node = protobuf:: PhysicalPlanNode :: try_decode ( plan_bytes) . map_err ( |e| {
53- PyRuntimeError :: new_err ( format ! (
54- "Unable to decode physical plan protobuf message: {}" ,
55- e
56- ) )
57- } ) ?;
53+ if let Ok ( py_plan) = py_plan. to_object ( py) . downcast_bound :: < PyExecutionPlan > ( py) {
54+ // For session contexts created with datafusion_ray.extended_session_context(), the inner
55+ // execution plan can be used as such (and the enabled extensions are all available).
56+ Ok ( py_plan. borrow ( ) . plan . clone ( ) )
57+ } else {
58+ // The session context originates from outside our library, so we'll grab the protobuf plan
59+ // by calling the python method with no extension codecs.
60+ let py_proto = py_plan. call_method0 ( "to_proto" ) ?;
61+ let plan_bytes: & [ u8 ] = py_proto. extract ( ) ?;
62+ let plan_node = protobuf:: PhysicalPlanNode :: try_decode ( plan_bytes) . map_err ( |e| {
63+ PyRuntimeError :: new_err ( format ! (
64+ "Unable to decode physical plan protobuf message: {}" ,
65+ e
66+ ) )
67+ } ) ?;
5868
59- let codec = ShuffleCodec { } ;
60- let runtime = RuntimeEnv :: default ( ) ;
61- let registry = SessionContext :: new ( ) ;
62- plan_node
63- . try_into_physical_plan ( & registry , & runtime , & codec )
64- . map_err ( |e| e . into ( ) )
69+ let runtime = RuntimeEnv :: default ( ) ;
70+ let registry = SessionContext :: new ( ) ;
71+ plan_node
72+ . try_into_physical_plan ( & registry , & runtime , Extensions :: codec ( ) )
73+ . map_err ( |e| e . into ( ) )
74+ }
6575}
6676
6777#[ pymethods]
@@ -88,14 +98,14 @@ impl PyContext {
8898 }
8999
90100 /// Plan a distributed SELECT query for executing against the Ray workers
91- pub fn plan ( & self , plan : & Bound < PyAny > ) -> PyResult < PyExecutionGraph > {
101+ pub fn plan ( & self , plan : & Bound < PyAny > , py : Python ) -> PyResult < PyExecutionGraph > {
92102 // println!("Planning {}", sql);
93103 // let df = wait_for_future(py, self.ctx.sql(sql))?;
94104 // let py_df = self.run_sql(sql, py)?;
95105 // let py_plan = py_df.call_method0(py, "execution_plan")?;
96106 // let py_plan = py_plan.bind(py);
97107
98- let plan = execution_plan_from_pyany ( plan) ?;
108+ let plan = execution_plan_from_pyany ( plan, py ) ?;
99109 let graph = make_execution_graph ( plan. clone ( ) ) ?;
100110
101111 // debug logging
@@ -143,9 +153,10 @@ pub fn serialize_execution_plan(
143153 plan : Arc < dyn ExecutionPlan > ,
144154 py : Python ,
145155) -> PyResult < Bound < ' _ , PyBytes > > {
146- let codec = ShuffleCodec { } ;
147- let proto =
148- datafusion_proto:: protobuf:: PhysicalPlanNode :: try_from_physical_plan ( plan. clone ( ) , & codec) ?;
156+ let proto = datafusion_proto:: protobuf:: PhysicalPlanNode :: try_from_physical_plan (
157+ plan. clone ( ) ,
158+ Extensions :: codec ( ) ,
159+ ) ?;
149160
150161 let bytes = proto. encode_to_vec ( ) ;
151162 Ok ( PyBytes :: new_bound ( py, & bytes) )
@@ -162,9 +173,8 @@ pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dy
162173 } ) ?;
163174
164175 let ctx = SessionContext :: new ( ) ;
165- let codec = ShuffleCodec { } ;
166176 let plan = proto_plan
167- . try_into_physical_plan ( & ctx, & ctx. runtime_env ( ) , & codec)
177+ . try_into_physical_plan ( & ctx, & ctx. runtime_env ( ) , Extensions :: codec ( ) )
168178 . map_err ( DataFusionError :: from) ?;
169179
170180 Ok ( plan)
0 commit comments