1- use crate :: composed_extension_codec:: ComposedPhysicalExtensionCodec ;
2- use crate :: context:: StageTaskContext ;
31use crate :: errors:: datafusion_error_to_tonic_status;
42use crate :: flight_service:: service:: ArrowFlightEndpoint ;
5- use crate :: plan:: ArrowFlightReadExecProtoCodec ;
3+ use crate :: plan:: DistributedCodec ;
4+ use crate :: stage:: { stage_from_proto, ExecutionStageProto } ;
65use arrow_flight:: encode:: FlightDataEncoderBuilder ;
76use arrow_flight:: error:: FlightError ;
87use arrow_flight:: flight_service_server:: FlightService ;
98use arrow_flight:: Ticket ;
10- use datafusion:: error:: DataFusionError ;
119use datafusion:: execution:: SessionStateBuilder ;
1210use datafusion:: optimizer:: OptimizerConfig ;
13- use datafusion:: physical_expr:: { Partitioning , PhysicalExpr } ;
1411use datafusion:: physical_plan:: ExecutionPlan ;
15- use datafusion_proto:: physical_plan:: from_proto:: parse_physical_exprs;
16- use datafusion_proto:: physical_plan:: to_proto:: serialize_physical_exprs;
17- use datafusion_proto:: physical_plan:: { AsExecutionPlan , PhysicalExtensionCodec } ;
18- use datafusion_proto:: protobuf:: { PhysicalExprNode , PhysicalPlanNode } ;
12+ use datafusion_proto:: physical_plan:: PhysicalExtensionCodec ;
1913use futures:: TryStreamExt ;
2014use prost:: Message ;
2115use std:: sync:: Arc ;
2216use tonic:: { Request , Response , Status } ;
23- use uuid:: Uuid ;
2417
2518#[ derive( Clone , PartialEq , :: prost:: Message ) ]
2619pub struct DoGet {
27- #[ prost( oneof = "DoGetInner" , tags = "1" ) ]
28- pub inner : Option < DoGetInner > ,
29- }
30-
31- #[ derive( Clone , PartialEq , prost:: Oneof ) ]
32- pub enum DoGetInner {
33- #[ prost( message, tag = "1" ) ]
34- RemotePlanExec ( RemotePlanExec ) ,
35- }
36-
37- #[ derive( Clone , PartialEq , :: prost:: Message ) ]
38- pub struct RemotePlanExec {
39- #[ prost( message, optional, boxed, tag = "1" ) ]
40- pub plan : Option < Box < PhysicalPlanNode > > ,
41- #[ prost( string, tag = "2" ) ]
42- pub stage_id : String ,
43- #[ prost( uint64, tag = "3" ) ]
44- pub task_idx : u64 ,
45- #[ prost( uint64, tag = "4" ) ]
46- pub output_task_idx : u64 ,
47- #[ prost( uint64, tag = "5" ) ]
48- pub output_tasks : u64 ,
49- #[ prost( message, repeated, tag = "6" ) ]
50- pub hash_expr : Vec < PhysicalExprNode > ,
51- }
52-
53- impl DoGet {
54- pub fn new_remote_plan_exec_ticket (
55- plan : Arc < dyn ExecutionPlan > ,
56- stage_id : Uuid ,
57- task_idx : usize ,
58- output_task_idx : usize ,
59- output_tasks : usize ,
60- hash_expr : & [ Arc < dyn PhysicalExpr > ] ,
61- extension_codec : & dyn PhysicalExtensionCodec ,
62- ) -> Result < Ticket , DataFusionError > {
63- let node = PhysicalPlanNode :: try_from_physical_plan ( plan, extension_codec) ?;
64- let do_get = Self {
65- inner : Some ( DoGetInner :: RemotePlanExec ( RemotePlanExec {
66- plan : Some ( Box :: new ( node) ) ,
67- stage_id : stage_id. to_string ( ) ,
68- task_idx : task_idx as u64 ,
69- output_task_idx : output_task_idx as u64 ,
70- output_tasks : output_tasks as u64 ,
71- hash_expr : serialize_physical_exprs ( hash_expr, extension_codec) ?,
72- } ) ) ,
73- } ;
74- Ok ( Ticket :: new ( do_get. encode_to_vec ( ) ) )
75- }
20+ /// The ExecutionStage, in protobuf form, that we are going to execute; the request is rejected as an invalid argument when it is missing
21+ #[ prost( message, optional, tag = "1" ) ]
22+ pub stage_proto : Option < ExecutionStageProto > ,
23+ /// The partition of the stage to execute; it is cast to `usize` and passed to the stage plan's `execute` call
24+ #[ prost( uint64, tag = "2" ) ]
25+ pub partition : u64 ,
7626}
7727
7828impl ArrowFlightEndpoint {
@@ -81,74 +31,43 @@ impl ArrowFlightEndpoint {
8131 request : Request < Ticket > ,
8232 ) -> Result < Response < <ArrowFlightEndpoint as FlightService >:: DoGetStream > , Status > {
8333 let Ticket { ticket } = request. into_inner ( ) ;
84- let action = DoGet :: decode ( ticket) . map_err ( |err| {
34+ let doget = DoGet :: decode ( ticket) . map_err ( |err| {
8535 Status :: invalid_argument ( format ! ( "Cannot decode DoGet message: {err}" ) )
8636 } ) ?;
8737
88- let Some ( action) = action. inner else {
89- return invalid_argument ( "DoGet message is empty" ) ;
90- } ;
91-
92- let DoGetInner :: RemotePlanExec ( action) = action;
38+ let stage_msg = doget
39+ . stage_proto
40+ . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage proto" ) ) ?;
9341
9442 let state_builder = SessionStateBuilder :: new ( )
9543 . with_runtime_env ( Arc :: clone ( & self . runtime ) )
9644 . with_default_features ( ) ;
9745
9846 let mut state = self . session_builder . on_new_session ( state_builder) . build ( ) ;
9947
100- let Some ( function_registry) = state. function_registry ( ) else {
101- return invalid_argument ( "FunctionRegistry not present in newly built SessionState" ) ;
102- } ;
48+ let function_registry = state. function_registry ( ) . ok_or ( Status :: invalid_argument (
49+ "FunctionRegistry not present in newly built SessionState" ,
50+ ) ) ? ;
10351
104- let Some ( plan_proto) = action. plan else {
105- return invalid_argument ( "RemotePlanExec is missing the plan" ) ;
106- } ;
52+ let codec = DistributedCodec { } ;
53+ let codec = Arc :: new ( codec) as Arc < dyn PhysicalExtensionCodec > ;
10754
108- let mut codec = ComposedPhysicalExtensionCodec :: default ( ) ;
109- codec. push ( ArrowFlightReadExecProtoCodec ) ;
110- codec. push_from_config ( state. config ( ) ) ;
111-
112- let plan = plan_proto
113- . try_into_physical_plan ( function_registry, & self . runtime , & codec)
114- . map_err ( |err| Status :: internal ( format ! ( "Cannot deserialize plan: {err}" ) ) ) ?;
115-
116- let stage_id = Uuid :: parse_str ( & action. stage_id ) . map_err ( |err| {
117- Status :: invalid_argument ( format ! (
118- "Cannot parse stage id '{}': {err}" ,
119- action. stage_id
120- ) )
121- } ) ?;
122-
123- let task_idx = action. task_idx as usize ;
124- let caller_actor_idx = action. output_task_idx as usize ;
125- let prev_n = action. output_tasks as usize ;
126- let partitioning = match parse_physical_exprs (
127- & action. hash_expr ,
128- function_registry,
129- & plan. schema ( ) ,
130- & codec,
131- ) {
132- Ok ( expr) if expr. is_empty ( ) => Partitioning :: Hash ( expr, prev_n) ,
133- Ok ( _) => Partitioning :: RoundRobinBatch ( prev_n) ,
134- Err ( err) => return invalid_argument ( format ! ( "Cannot parse hash expressions {err}" ) ) ,
135- } ;
55+ let stage = stage_from_proto ( stage_msg, function_registry, & self . runtime . as_ref ( ) , codec)
56+ . map ( Arc :: new)
57+ . map_err ( |err| Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) ) ) ?;
13658
59+ // Add the extensions that might be required for ExecutionPlan nodes in the plan
13760 let config = state. config_mut ( ) ;
13861 config. set_extension ( Arc :: clone ( & self . channel_manager ) ) ;
139- config. set_extension ( Arc :: new ( StageTaskContext { task_idx } ) ) ;
140-
141- let stream_partitioner = self
142- . partitioner_registry
143- . get_or_create_stream_partitioner ( stage_id, task_idx, plan, partitioning)
144- . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
62+ config. set_extension ( stage. clone ( ) ) ;
14563
146- let stream = stream_partitioner
147- . execute ( caller_actor_idx, state. task_ctx ( ) )
148- . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
64+ let stream = stage
65+ . plan
66+ . execute ( doget. partition as usize , state. task_ctx ( ) )
67+ . map_err ( |err| Status :: internal ( format ! ( "Error executing stage plan: {err:#?}" ) ) ) ?;
14968
15069 let flight_data_stream = FlightDataEncoderBuilder :: new ( )
151- . with_schema ( stream_partitioner . schema ( ) )
70+ . with_schema ( stage . plan . schema ( ) . clone ( ) )
15271 . build ( stream. map_err ( |err| {
15372 FlightError :: Tonic ( Box :: new ( datafusion_error_to_tonic_status ( & err) ) )
15473 } ) ) ;
0 commit comments