@@ -20,14 +20,13 @@ use crate::protobuf::{RayShuffleReaderExecNode, RayShuffleWriterExecNode, RaySql
 use crate::shuffle::{RayShuffleReaderExec, RayShuffleWriterExec};
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{DataFusionError, Result};
-use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
 use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
+use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
-use datafusion_proto::protobuf::{self, PhysicalHashRepartition, PhysicalPlanNode};
+use datafusion_proto::protobuf::{self, PhysicalHashRepartition};
 use prost::Message;
 use std::sync::Arc;
 
@@ -38,48 +37,62 @@ impl PhysicalExtensionCodec for ShuffleCodec {
     fn try_decode(
         &self,
         buf: &[u8],
-        _inputs: &[Arc<dyn ExecutionPlan>],
+        inputs: &[Arc<dyn ExecutionPlan>],
         registry: &dyn FunctionRegistry,
     ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
         // decode bytes to protobuf struct
         let node = RaySqlExecNode::decode(buf)
             .map_err(|e| DataFusionError::Internal(format!("failed to decode plan: {e:?}")))?;
         let extension_codec = DefaultPhysicalExtensionCodec {};
-        match node.plan_type {
-            Some(PlanType::RayShuffleReader(reader)) => {
-                let schema = reader.schema.as_ref().unwrap();
-                let schema: SchemaRef = Arc::new(schema.try_into().unwrap());
-                let hash_part = parse_protobuf_hash_partitioning(
-                    reader.partitioning.as_ref(),
-                    registry,
-                    &schema,
-                    &extension_codec,
-                )?;
-                Ok(Arc::new(RayShuffleReaderExec::new(
-                    reader.stage_id as usize,
-                    schema,
-                    hash_part.unwrap(),
-                )))
+        if let Some(plan_type) = node.plan_type {
+            match plan_type {
+                PlanType::RayShuffleReader(reader) => {
+                    let schema = reader.schema.as_ref().ok_or_else(|| {
+                        DataFusionError::Execution("invalid encoded schema".into())
+                    })?;
+                    let schema: SchemaRef = Arc::new(schema.try_into()?);
+                    let hash_part = parse_protobuf_hash_partitioning(
+                        reader.partitioning.as_ref(),
+                        registry,
+                        &schema,
+                        &extension_codec,
+                    )?
+                    .ok_or_else(|| {
+                        DataFusionError::Execution("missing partitioning info".into())
+                    })?;
+                    Ok(Arc::new(RayShuffleReaderExec::new(
+                        reader.stage_id as usize,
+                        schema,
+                        hash_part,
+                    )))
+                }
+                PlanType::RayShuffleWriter(writer) => {
+                    let plan = inputs
+                        .first()
+                        .ok_or_else(|| {
+                            DataFusionError::Execution("No inputs for shuffle writer".into())
+                        })?
+                        .to_owned();
+                    let hash_part = parse_protobuf_hash_partitioning(
+                        writer.partitioning.as_ref(),
+                        registry,
+                        plan.schema().as_ref(),
+                        &extension_codec,
+                    )?
+                    .ok_or_else(|| {
+                        DataFusionError::Execution("missing partitioning info".into())
+                    })?;
+                    Ok(Arc::new(RayShuffleWriterExec::new(
+                        writer.stage_id as usize,
+                        plan,
+                        hash_part,
+                    )))
+                }
             }
-            Some(PlanType::RayShuffleWriter(writer)) => {
-                let plan = writer.plan.unwrap().try_into_physical_plan(
-                    registry,
-                    &RuntimeEnv::default(),
-                    self,
-                )?;
-                let hash_part = parse_protobuf_hash_partitioning(
-                    writer.partitioning.as_ref(),
-                    registry,
-                    plan.schema().as_ref(),
-                    &extension_codec,
-                )?;
-                Ok(Arc::new(RayShuffleWriterExec::new(
-                    writer.stage_id as usize,
-                    plan,
-                    hash_part.unwrap(),
-                )))
-            }
-            _ => unreachable!(),
+        } else {
+            Err(DataFusionError::Execution(
+                "RaySqlExecNode with no plan_type".into(),
+            ))
         }
     }
 
@@ -88,7 +101,7 @@ impl PhysicalExtensionCodec for ShuffleCodec {
         node: Arc<dyn ExecutionPlan>,
         buf: &mut Vec<u8>,
     ) -> Result<(), DataFusionError> {
-        let plan = if let Some(reader) = node.as_any().downcast_ref::<RayShuffleReaderExec>() {
+        if let Some(reader) = node.as_any().downcast_ref::<RayShuffleReaderExec>() {
             let schema: protobuf::Schema = reader.schema().try_into().unwrap();
             let partitioning =
                 encode_partitioning_scheme(reader.properties().output_partitioning())?;
@@ -97,22 +110,27 @@ impl PhysicalExtensionCodec for ShuffleCodec {
                 schema: Some(schema),
                 partitioning: Some(partitioning),
             };
-            PlanType::RayShuffleReader(reader)
+            PlanType::RayShuffleReader(reader).encode(buf);
+            Ok(())
         } else if let Some(writer) = node.as_any().downcast_ref::<RayShuffleWriterExec>() {
-            let plan = PhysicalPlanNode::try_from_physical_plan(writer.plan.clone(), self)?;
             let partitioning =
                 encode_partitioning_scheme(writer.properties().output_partitioning())?;
             let writer = RayShuffleWriterExecNode {
                 stage_id: writer.stage_id as u32,
-                plan: Some(plan),
+                // No need to redundantly serialize the child plan, as input plan(s) are recursively
+                // serialized by PhysicalPlanNode and will be available as `inputs` in `try_decode`.
+                // TODO: remove this field from the proto definition?
+                plan: None,
                 partitioning: Some(partitioning),
             };
-            PlanType::RayShuffleWriter(writer)
+            PlanType::RayShuffleWriter(writer).encode(buf);
+            Ok(())
         } else {
-            unreachable!()
-        };
-        plan.encode(buf);
-        Ok(())
+            Err(DataFusionError::Execution(format!(
+                "Unsupported plan node: {}",
+                node.name()
+            )))
+        }
     }
 }
 
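Note on the `try_encode` change above: `PhysicalPlanNode` serializes and deserializes extension nodes' children itself and hands the decoded children back to `try_decode` as `inputs`, which is why the writer node no longer embeds its child plan. A rough, hypothetical round-trip sketch of that contract follows; the `roundtrip` helper and the session/registry setup are assumptions for illustration, not part of this commit.

use std::sync::Arc;

use datafusion::common::{DataFusionError, Result};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
use datafusion_proto::protobuf::PhysicalPlanNode;
use prost::Message;

// Hypothetical helper (not part of this commit): serialize a physical plan to bytes and back
// using an extension codec such as the ShuffleCodec above. PhysicalPlanNode walks the plan
// tree and serializes child plans itself, so the codec only encodes its own extension nodes.
fn roundtrip(
    plan: Arc<dyn ExecutionPlan>,
    codec: &dyn PhysicalExtensionCodec,
) -> Result<Arc<dyn ExecutionPlan>> {
    // Encode: the whole tree (including any RayShuffleWriterExec children) becomes one message.
    let proto = PhysicalPlanNode::try_from_physical_plan(plan, codec)?;
    let bytes = proto.encode_to_vec();

    // Decode: children are reconstructed first, then passed to the codec's try_decode as `inputs`.
    let ctx = SessionContext::new();
    let state = ctx.state();
    let proto = PhysicalPlanNode::decode(bytes.as_slice())
        .map_err(|e| DataFusionError::Internal(format!("failed to decode plan: {e:?}")))?;
    proto.try_into_physical_plan(&state, &RuntimeEnv::default(), codec)
}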