11use crate :: composed_extension_codec:: ComposedPhysicalExtensionCodec ;
22use crate :: errors:: datafusion_error_to_tonic_status;
33use crate :: flight_service:: service:: ArrowFlightEndpoint ;
4- use crate :: plan:: DistributedCodec ;
5- use crate :: stage:: { stage_from_proto, ExecutionStageProto } ;
4+ use crate :: plan:: { DistributedCodec , PartitionGroup } ;
5+ use crate :: stage:: { stage_from_proto, ExecutionStage , ExecutionStageProto , StageKey } ;
66use crate :: user_provided_codec:: get_user_codec;
77use arrow_flight:: encode:: FlightDataEncoderBuilder ;
88use arrow_flight:: error:: FlightError ;
99use arrow_flight:: flight_service_server:: FlightService ;
1010use arrow_flight:: Ticket ;
11- use datafusion:: execution:: SessionStateBuilder ;
11+ use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
1212use datafusion:: optimizer:: OptimizerConfig ;
13+ use datafusion:: physical_plan:: ExecutionPlan ;
1314use futures:: TryStreamExt ;
1415use prost:: Message ;
1516use std:: sync:: Arc ;
17+ use tokio:: sync:: OnceCell ;
1618use tonic:: { Request , Response , Status } ;
1719
1820#[ derive( Clone , PartialEq , :: prost:: Message ) ]
1921pub struct DoGet {
2022 /// The ExecutionStage that we are going to execute
2123 #[ prost( message, optional, tag = "1" ) ]
2224 pub stage_proto : Option < ExecutionStageProto > ,
23- /// the partition of the stage to execute
25+ /// The index to the task within the stage that we want to execute
2426 #[ prost( uint64, tag = "2" ) ]
27+ pub task_number : u64 ,
28+ /// the partition number we want to execute
29+ #[ prost( uint64, tag = "3" ) ]
2530 pub partition : u64 ,
31+ /// The stage key that identifies the stage. This is useful to keep
32+ /// outside of the stage proto as it is used to store the stage
33+ /// and we may not need to deserialize the entire stage proto
34+ /// if we already have stored it
35+ #[ prost( message, optional, tag = "4" ) ]
36+ pub stage_key : Option < StageKey > ,
2637}
2738
2839impl ArrowFlightEndpoint {
@@ -35,42 +46,28 @@ impl ArrowFlightEndpoint {
3546 Status :: invalid_argument ( format ! ( "Cannot decode DoGet message: {err}" ) )
3647 } ) ?;
3748
38- let stage_msg = doget
39- . stage_proto
40- . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage proto" ) ) ?;
49+ let partition = doget. partition as usize ;
50+ let task_number = doget . task_number as usize ;
51+ let ( mut state , stage) = self . get_state_and_stage ( doget ) . await ?;
4152
42- let state_builder = SessionStateBuilder :: new ( )
43- . with_runtime_env ( Arc :: clone ( & self . runtime ) )
44- . with_default_features ( ) ;
53+ // find out which partition group we are executing
54+ let task = stage
55+ . tasks
56+ . get ( task_number)
57+ . ok_or ( Status :: invalid_argument ( format ! (
58+ "Task number {} not found in stage {}" ,
59+ task_number,
60+ stage. name( )
61+ ) ) ) ?;
4562
46- let mut state = self . session_builder . on_new_session ( state_builder) . build ( ) ;
63+ let partition_group =
64+ PartitionGroup ( task. partition_group . iter ( ) . map ( |p| * p as usize ) . collect ( ) ) ;
65+ state. config_mut ( ) . set_extension ( Arc :: new ( partition_group) ) ;
4766
48- let function_registry = state. function_registry ( ) . ok_or ( Status :: invalid_argument (
49- "FunctionRegistry not present in newly built SessionState" ,
50- ) ) ?;
51-
52- let mut combined_codec = ComposedPhysicalExtensionCodec :: default ( ) ;
53- combined_codec. push ( DistributedCodec ) ;
54- if let Some ( ref user_codec) = get_user_codec ( state. config ( ) ) {
55- combined_codec. push_arc ( Arc :: clone ( & user_codec) ) ;
56- }
57-
58- let mut stage = stage_from_proto (
59- stage_msg,
60- function_registry,
61- & self . runtime . as_ref ( ) ,
62- & combined_codec,
63- )
64- . map_err ( |err| Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) ) ) ?;
65- let inner_plan = Arc :: clone ( & stage. plan ) ;
66-
67- // Add the extensions that might be required for ExecutionPlan nodes in the plan
68- let config = state. config_mut ( ) ;
69- config. set_extension ( Arc :: clone ( & self . channel_manager ) ) ;
70- config. set_extension ( Arc :: new ( stage) ) ;
67+ let inner_plan = stage. plan . clone ( ) ;
7168
7269 let stream = inner_plan
73- . execute ( doget . partition as usize , state. task_ctx ( ) )
70+ . execute ( partition, state. task_ctx ( ) )
7471 . map_err ( |err| Status :: internal ( format ! ( "Error executing stage plan: {err:#?}" ) ) ) ?;
7572
7673 let flight_data_stream = FlightDataEncoderBuilder :: new ( )
@@ -86,4 +83,59 @@ impl ArrowFlightEndpoint {
8683 } ,
8784 ) ) ) )
8885 }
86+
87+ async fn get_state_and_stage (
88+ & self ,
89+ doget : DoGet ,
90+ ) -> Result < ( SessionState , Arc < ExecutionStage > ) , Status > {
91+ let key = doget
92+ . stage_key
93+ . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
94+ let once_stage = self . stages . entry ( key) . or_default ( ) ;
95+
96+ let ( state, stage) = once_stage
97+ . get_or_try_init ( || async {
98+ let stage_proto = doget
99+ . stage_proto
100+ . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage proto" ) ) ?;
101+
102+ let state_builder = SessionStateBuilder :: new ( )
103+ . with_runtime_env ( Arc :: clone ( & self . runtime ) )
104+ . with_default_features ( ) ;
105+
106+ let mut state = self . session_builder . on_new_session ( state_builder) . build ( ) ;
107+
108+ let function_registry =
109+ state. function_registry ( ) . ok_or ( Status :: invalid_argument (
110+ "FunctionRegistry not present in newly built SessionState" ,
111+ ) ) ?;
112+
113+ let mut combined_codec = ComposedPhysicalExtensionCodec :: default ( ) ;
114+ combined_codec. push ( DistributedCodec ) ;
115+ if let Some ( ref user_codec) = get_user_codec ( state. config ( ) ) {
116+ combined_codec. push_arc ( Arc :: clone ( user_codec) ) ;
117+ }
118+
119+ let stage = stage_from_proto (
120+ stage_proto,
121+ function_registry,
122+ self . runtime . as_ref ( ) ,
123+ & combined_codec,
124+ )
125+ . map ( Arc :: new)
126+ . map_err ( |err| {
127+ Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) )
128+ } ) ?;
129+
130+ // Add the extensions that might be required for ExecutionPlan nodes in the plan
131+ let config = state. config_mut ( ) ;
132+ config. set_extension ( Arc :: clone ( & self . channel_manager ) ) ;
133+ config. set_extension ( stage. clone ( ) ) ;
134+
135+ Ok :: < _ , Status > ( ( state, stage) )
136+ } )
137+ . await ?;
138+
139+ Ok ( ( state. clone ( ) , stage. clone ( ) ) )
140+ }
89141}
0 commit comments