@@ -10,15 +10,24 @@ use arrow_flight::{
1010use async_trait:: async_trait;
1111use datafusion:: error:: DataFusionError ;
1212use datafusion:: execution:: runtime_env:: RuntimeEnv ;
13+ use datafusion:: physical_plan:: ExecutionPlan ;
1314use futures:: stream:: BoxStream ;
1415use std:: sync:: Arc ;
1516use tokio:: sync:: OnceCell ;
1617use tonic:: { Request , Response , Status , Streaming } ;
1718
19+ #[ allow( clippy:: type_complexity) ]
20+ #[ derive( Default ) ]
21+ pub ( super ) struct ArrowFlightEndpointHooks {
22+ pub ( super ) on_plan :
23+ Vec < Arc < dyn Fn ( Arc < dyn ExecutionPlan > ) -> Arc < dyn ExecutionPlan > + Sync + Send > > ,
24+ }
25+
1826pub struct ArrowFlightEndpoint {
1927 pub ( super ) runtime : Arc < RuntimeEnv > ,
2028 pub ( super ) task_data_entries : Arc < TTLMap < StageKey , Arc < OnceCell < TaskData > > > > ,
2129 pub ( super ) session_builder : Arc < dyn DistributedSessionBuilder + Send + Sync > ,
30+ pub ( super ) hooks : ArrowFlightEndpointHooks ,
2231}
2332
2433impl ArrowFlightEndpoint {
@@ -30,8 +39,21 @@ impl ArrowFlightEndpoint {
3039 runtime : Arc :: new ( RuntimeEnv :: default ( ) ) ,
3140 task_data_entries : Arc :: new ( ttl_map) ,
3241 session_builder : Arc :: new ( session_builder) ,
42+ hooks : ArrowFlightEndpointHooks :: default ( ) ,
3343 } )
3444 }
45+
46+ /// Adds a callback for when an [ExecutionPlan] is received in the `do_get` call.
47+ ///
48+ /// The callback takes the plan and returns another plan that must be either the same,
49+ /// or equivalent in terms of execution. Mutating the plan by adding nodes or removing them
50+ /// will make the query blow up in unexpected ways.
51+ pub fn add_on_plan_hook (
52+ & mut self ,
53+ hook : impl Fn ( Arc < dyn ExecutionPlan > ) -> Arc < dyn ExecutionPlan > + Sync + Send + ' static ,
54+ ) {
55+ self . hooks . on_plan . push ( Arc :: new ( hook) ) ;
56+ }
3557}
3658
3759#[ async_trait]
0 commit comments