11use crate :: composed_extension_codec:: ComposedPhysicalExtensionCodec ;
22use crate :: errors:: datafusion_error_to_tonic_status;
33use crate :: flight_service:: service:: ArrowFlightEndpoint ;
4- use crate :: plan:: DistributedCodec ;
5- use crate :: stage:: { stage_from_proto, ExecutionStageProto } ;
4+ use crate :: plan:: { DistributedCodec , PartitionGroup } ;
5+ use crate :: stage:: { stage_from_proto, ExecutionStage , ExecutionStageProto } ;
66use crate :: user_provided_codec:: get_user_codec;
77use arrow_flight:: encode:: FlightDataEncoderBuilder ;
88use arrow_flight:: error:: FlightError ;
99use arrow_flight:: flight_service_server:: FlightService ;
1010use arrow_flight:: Ticket ;
11- use datafusion:: execution:: SessionStateBuilder ;
11+ use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
1212use datafusion:: optimizer:: OptimizerConfig ;
13- use datafusion:: prelude:: SessionContext ;
1413use futures:: TryStreamExt ;
1514use prost:: Message ;
1615use std:: sync:: Arc ;
1716use tonic:: { Request , Response , Status } ;
1817
18+ use super :: service:: StageKey ;
19+
1920#[ derive( Clone , PartialEq , :: prost:: Message ) ]
2021pub struct DoGet {
2122 /// The ExecutionStage that we are going to execute
2223 #[ prost( message, optional, tag = "1" ) ]
2324 pub stage_proto : Option < ExecutionStageProto > ,
24- /// the partition of the stage to execute
25+ /// The index to the task within the stage that we want to execute
2526 #[ prost( uint64, tag = "2" ) ]
27+ pub task_number : u64 ,
28+ /// the partition number we want to execute
29+ #[ prost( uint64, tag = "3" ) ]
2630 pub partition : u64 ,
31+ /// The stage key that identifies the stage. This is useful to keep
32+ /// outside of the stage proto as it is used to store the stage
33+ /// and we may not need to deserialize the entire stage proto
34+ /// if we already have stored it
35+ #[ prost( message, optional, tag = "4" ) ]
36+ pub stage_key : Option < StageKey > ,
2737}
2838
2939impl ArrowFlightEndpoint {
@@ -36,59 +46,28 @@ impl ArrowFlightEndpoint {
3646 Status :: invalid_argument ( format ! ( "Cannot decode DoGet message: {err}" ) )
3747 } ) ?;
3848
39- let stage_msg = doget
40- . stage_proto
41- . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage proto" ) ) ?;
42-
43- let state_builder = SessionStateBuilder :: new ( )
44- . with_runtime_env ( Arc :: clone ( & self . runtime ) )
45- . with_default_features ( ) ;
46- let state_builder = self
47- . session_builder
48- . session_state_builder ( state_builder)
49- . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
50-
51- let state = state_builder. build ( ) ;
52- let mut state = self
53- . session_builder
54- . session_state ( state)
55- . await
56- . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
57-
58- let function_registry = state. function_registry ( ) . ok_or ( Status :: invalid_argument (
59- "FunctionRegistry not present in newly built SessionState" ,
60- ) ) ?;
61-
62- let mut combined_codec = ComposedPhysicalExtensionCodec :: default ( ) ;
63- combined_codec. push ( DistributedCodec ) ;
64- if let Some ( ref user_codec) = get_user_codec ( state. config ( ) ) {
65- combined_codec. push_arc ( Arc :: clone ( & user_codec) ) ;
66- }
67-
68- let stage = stage_from_proto (
69- stage_msg,
70- function_registry,
71- & self . runtime . as_ref ( ) ,
72- & combined_codec,
73- )
74- . map_err ( |err| Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) ) ) ?;
75- let inner_plan = Arc :: clone ( & stage. plan ) ;
76-
77- // Add the extensions that might be required for ExecutionPlan nodes in the plan
78- let config = state. config_mut ( ) ;
79- config. set_extension ( Arc :: clone ( & self . channel_manager ) ) ;
80- config. set_extension ( Arc :: new ( stage) ) ;
81-
82- let ctx = SessionContext :: new_with_state ( state) ;
83-
84- let ctx = self
85- . session_builder
86- . session_context ( ctx)
87- . await
88- . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
49+ let partition = doget. partition as usize ;
50+ let task_number = doget. task_number as usize ;
51+ let ( mut state, stage) = self . get_state_and_stage ( doget) . await ?;
52+
53+ // find out which partition group we are executing
54+ let task = stage
55+ . tasks
56+ . get ( task_number)
57+ . ok_or ( Status :: invalid_argument ( format ! (
58+ "Task number {} not found in stage {}" ,
59+ task_number,
60+ stage. name( )
61+ ) ) ) ?;
62+
63+ let partition_group =
64+ PartitionGroup ( task. partition_group . iter ( ) . map ( |p| * p as usize ) . collect ( ) ) ;
65+ state. config_mut ( ) . set_extension ( Arc :: new ( partition_group) ) ;
66+
67+ let inner_plan = stage. plan . clone ( ) ;
8968
9069 let stream = inner_plan
91- . execute ( doget . partition as usize , ctx . task_ctx ( ) )
70+ . execute ( partition, state . task_ctx ( ) )
9271 . map_err ( |err| Status :: internal ( format ! ( "Error executing stage plan: {err:#?}" ) ) ) ?;
9372
9473 let flight_data_stream = FlightDataEncoderBuilder :: new ( )
@@ -104,4 +83,71 @@ impl ArrowFlightEndpoint {
10483 } ,
10584 ) ) ) )
10685 }
86+
87+ async fn get_state_and_stage (
88+ & self ,
89+ doget : DoGet ,
90+ ) -> Result < ( SessionState , Arc < ExecutionStage > ) , Status > {
91+ let key = doget
92+ . stage_key
93+ . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage key" ) ) ?;
94+ let once_stage = {
95+ let entry = self . stages . entry ( key) . or_default ( ) ;
96+ Arc :: clone ( & entry)
97+ } ;
98+
99+ let ( state, stage) = once_stage
100+ . get_or_try_init ( || async {
101+ let stage_proto = doget
102+ . stage_proto
103+ . ok_or ( Status :: invalid_argument ( "DoGet is missing the stage proto" ) ) ?;
104+
105+ let state_builder = SessionStateBuilder :: new ( )
106+ . with_runtime_env ( Arc :: clone ( & self . runtime ) )
107+ . with_default_features ( ) ;
108+ let state_builder = self
109+ . session_builder
110+ . session_state_builder ( state_builder)
111+ . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
112+
113+ let state = state_builder. build ( ) ;
114+ let mut state = self
115+ . session_builder
116+ . session_state ( state)
117+ . await
118+ . map_err ( |err| datafusion_error_to_tonic_status ( & err) ) ?;
119+
120+ let function_registry =
121+ state. function_registry ( ) . ok_or ( Status :: invalid_argument (
122+ "FunctionRegistry not present in newly built SessionState" ,
123+ ) ) ?;
124+
125+ let mut combined_codec = ComposedPhysicalExtensionCodec :: default ( ) ;
126+ combined_codec. push ( DistributedCodec ) ;
127+ if let Some ( ref user_codec) = get_user_codec ( state. config ( ) ) {
128+ combined_codec. push_arc ( Arc :: clone ( user_codec) ) ;
129+ }
130+
131+ let stage = stage_from_proto (
132+ stage_proto,
133+ function_registry,
134+ self . runtime . as_ref ( ) ,
135+ & combined_codec,
136+ )
137+ . map ( Arc :: new)
138+ . map_err ( |err| {
139+ Status :: invalid_argument ( format ! ( "Cannot decode stage proto: {err}" ) )
140+ } ) ?;
141+
142+ // Add the extensions that might be required for ExecutionPlan nodes in the plan
143+ let config = state. config_mut ( ) ;
144+ config. set_extension ( Arc :: clone ( & self . channel_manager ) ) ;
145+ config. set_extension ( stage. clone ( ) ) ;
146+
147+ Ok :: < _ , Status > ( ( state, stage) )
148+ } )
149+ . await ?;
150+
151+ Ok ( ( state. clone ( ) , stage. clone ( ) ) )
152+ }
107153}
0 commit comments