11use crate :: composed_extension_codec:: ComposedPhysicalExtensionCodec ;
2+ use crate :: context:: StageTaskContext ;
23use crate :: errors:: datafusion_error_to_tonic_status;
34use crate :: flight_service:: service:: ArrowFlightEndpoint ;
45use crate :: plan:: ArrowFlightReadExecProtoCodec ;
5- use crate :: stage_delegation:: { ActorContext , StageContext } ;
66use arrow_flight:: encode:: FlightDataEncoderBuilder ;
77use arrow_flight:: error:: FlightError ;
88use arrow_flight:: flight_service_server:: FlightService ;
99use arrow_flight:: Ticket ;
1010use datafusion:: error:: DataFusionError ;
1111use datafusion:: execution:: SessionStateBuilder ;
1212use datafusion:: optimizer:: OptimizerConfig ;
13- use datafusion:: physical_expr:: Partitioning ;
13+ use datafusion:: physical_expr:: { Partitioning , PhysicalExpr } ;
1414use datafusion:: physical_plan:: ExecutionPlan ;
15- use datafusion_proto:: physical_plan:: from_proto:: parse_protobuf_partitioning;
15+ use datafusion_proto:: physical_plan:: from_proto:: parse_physical_exprs;
16+ use datafusion_proto:: physical_plan:: to_proto:: serialize_physical_exprs;
1617use datafusion_proto:: physical_plan:: { AsExecutionPlan , PhysicalExtensionCodec } ;
17- use datafusion_proto:: protobuf:: PhysicalPlanNode ;
18+ use datafusion_proto:: protobuf:: { PhysicalExprNode , PhysicalPlanNode } ;
1819use futures:: TryStreamExt ;
1920use prost:: Message ;
2021use std:: sync:: Arc ;
2122use tonic:: { Request , Response , Status } ;
23+ use uuid:: Uuid ;
2224
2325#[ derive( Clone , PartialEq , :: prost:: Message ) ]
2426pub struct DoGet {
@@ -35,26 +37,38 @@ pub enum DoGetInner {
/// Protobuf message describing a plan fragment to execute on a remote worker,
/// together with the stage/task coordinates needed to route its output.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RemotePlanExec {
    /// Serialized DataFusion physical plan to deserialize and execute.
    #[prost(message, optional, boxed, tag = "1")]
    pub plan: Option<Box<PhysicalPlanNode>>,
    /// Stage identifier, transported as a stringified UUID.
    #[prost(string, tag = "2")]
    pub stage_id: String,
    /// Index of this task within its stage.
    #[prost(uint64, tag = "3")]
    pub task_idx: u64,
    /// Index of the downstream (calling) task requesting this stream.
    #[prost(uint64, tag = "4")]
    pub output_task_idx: u64,
    /// Total number of tasks in the downstream stage; used as the
    /// output partition count when repartitioning.
    #[prost(uint64, tag = "5")]
    pub output_tasks: u64,
    /// Serialized hash-partitioning expressions; empty when the output
    /// is not hash-partitioned.
    #[prost(message, repeated, tag = "6")]
    pub hash_expr: Vec<PhysicalExprNode>,
}
4452
4553impl DoGet {
4654 pub fn new_remote_plan_exec_ticket (
4755 plan : Arc < dyn ExecutionPlan > ,
48- stage_context : StageContext ,
49- actor_context : ActorContext ,
56+ stage_id : Uuid ,
57+ task_idx : usize ,
58+ output_task_idx : usize ,
59+ output_tasks : usize ,
60+ hash_expr : & [ Arc < dyn PhysicalExpr > ] ,
5061 extension_codec : & dyn PhysicalExtensionCodec ,
5162 ) -> Result < Ticket , DataFusionError > {
5263 let node = PhysicalPlanNode :: try_from_physical_plan ( plan, extension_codec) ?;
5364 let do_get = Self {
5465 inner : Some ( DoGetInner :: RemotePlanExec ( RemotePlanExec {
5566 plan : Some ( Box :: new ( node) ) ,
56- stage_context : Some ( stage_context) ,
57- actor_context : Some ( actor_context) ,
67+ stage_id : stage_id. to_string ( ) ,
68+ task_idx : task_idx as u64 ,
69+ output_task_idx : output_task_idx as u64 ,
70+ output_tasks : output_tasks as u64 ,
71+ hash_expr : serialize_physical_exprs ( hash_expr, extension_codec) ?,
5872 } ) ) ,
5973 } ;
6074 Ok ( Ticket :: new ( do_get. encode_to_vec ( ) ) )
@@ -91,14 +105,6 @@ impl ArrowFlightEndpoint {
91105 return invalid_argument ( "RemotePlanExec is missing the plan" ) ;
92106 } ;
93107
94- let Some ( stage_context) = action. stage_context else {
95- return invalid_argument ( "RemotePlanExec is missing the stage context" ) ;
96- } ;
97-
98- let Some ( actor_context) = action. actor_context else {
99- return invalid_argument ( "RemotePlanExec is missing the actor context" ) ;
100- } ;
101-
102108 let mut codec = ComposedPhysicalExtensionCodec :: default ( ) ;
103109 codec. push ( ArrowFlightReadExecProtoCodec ) ;
104110 codec. push_from_config ( state. config ( ) ) ;
@@ -107,40 +113,34 @@ impl ArrowFlightEndpoint {
107113 . try_into_physical_plan ( function_registry, & self . runtime , & codec)
108114 . map_err ( |err| Status :: internal ( format ! ( "Cannot deserialize plan: {err}" ) ) ) ?;
109115
110- let stage_id = stage_context. id . clone ( ) ;
111- let caller_actor_idx = actor_context. caller_actor_idx as usize ;
112- let actor_idx = actor_context. actor_idx as usize ;
113- let prev_n = stage_context. prev_actors as usize ;
114- let partitioning = match parse_protobuf_partitioning (
115- stage_context. partitioning . as_ref ( ) ,
116+ let stage_id = Uuid :: parse_str ( & action. stage_id ) . map_err ( |err| {
117+ Status :: invalid_argument ( format ! (
118+ "Cannot parse stage id '{}': {err}" ,
119+ action. stage_id
120+ ) )
121+ } ) ?;
122+
123+ let task_idx = action. task_idx as usize ;
124+ let caller_actor_idx = action. output_task_idx as usize ;
125+ let prev_n = action. output_tasks as usize ;
126+ let partitioning = match parse_physical_exprs (
127+ & action. hash_expr ,
116128 function_registry,
117129 & plan. schema ( ) ,
118130 & codec,
119131 ) {
120- // We need to replace the partition count in the provided Partitioning scheme with
121- // the number of actors in the previous stage. ArrowFlightReadExec might be declaring
122- // N partitions, but each ArrowFlightReadExec::execute(n) call will go to a different
123- // actor in the next stage.
124- //
125- // Each actor in that next stage (us here) needs to expose as many partitioned streams
126- // as actors exist on its previous stage.
127- Ok ( Some ( partitioning) ) => match partitioning {
128- Partitioning :: RoundRobinBatch ( _) => Partitioning :: RoundRobinBatch ( prev_n) ,
129- Partitioning :: Hash ( expr, _) => Partitioning :: Hash ( expr, prev_n) ,
130- Partitioning :: UnknownPartitioning ( _) => Partitioning :: UnknownPartitioning ( prev_n) ,
131- } ,
132- Ok ( None ) => return invalid_argument ( "Missing partitioning" ) ,
133- Err ( err) => return invalid_argument ( format ! ( "Cannot parse partitioning {err}" ) ) ,
132+ Ok ( expr) if expr. is_empty ( ) => Partitioning :: Hash ( expr, prev_n) ,
133+ Ok ( _) => Partitioning :: RoundRobinBatch ( prev_n) ,
134+ Err ( err) => return invalid_argument ( format ! ( "Cannot parse hash expressions {err}" ) ) ,
134135 } ;
136+
135137 let config = state. config_mut ( ) ;
136- config. set_extension ( Arc :: clone ( & self . stage_delegation ) ) ;
137138 config. set_extension ( Arc :: clone ( & self . channel_manager ) ) ;
138- config. set_extension ( Arc :: new ( stage_context) ) ;
139- config. set_extension ( Arc :: new ( actor_context) ) ;
139+ config. set_extension ( Arc :: new ( StageTaskContext { task_idx } ) ) ;
140140
141141 let stream_partitioner = self
142142 . partitioner_registry
143- . get_or_create_stream_partitioner ( stage_id, actor_idx , plan, partitioning)
143+ . get_or_create_stream_partitioner ( stage_id, task_idx , plan, partitioning)
144144 . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
145145
146146 let stream = stream_partitioner
0 commit comments