@@ -10,15 +10,30 @@ use arrow_flight::{
1010use async_trait:: async_trait;
1111use datafusion:: error:: DataFusionError ;
1212use datafusion:: execution:: runtime_env:: RuntimeEnv ;
13+ use datafusion:: physical_plan:: ExecutionPlan ;
1314use futures:: stream:: BoxStream ;
1415use std:: sync:: Arc ;
1516use tokio:: sync:: OnceCell ;
1617use tonic:: { Request , Response , Status , Streaming } ;
1718
19+ #[ allow( clippy:: type_complexity) ]
20+ pub ( super ) struct ArrowFlightEndpointHooks {
21+ pub ( super ) on_plan : Arc < dyn Fn ( Arc < dyn ExecutionPlan > ) -> Arc < dyn ExecutionPlan > + Sync + Send > ,
22+ }
23+
24+ impl Default for ArrowFlightEndpointHooks {
25+ fn default ( ) -> Self {
26+ Self {
27+ on_plan : Arc :: new ( |plan| plan) ,
28+ }
29+ }
30+ }
31+
1832pub struct ArrowFlightEndpoint {
1933 pub ( super ) runtime : Arc < RuntimeEnv > ,
2034 pub ( super ) task_data_entries : Arc < TTLMap < StageKey , Arc < OnceCell < TaskData > > > > ,
2135 pub ( super ) session_builder : Arc < dyn DistributedSessionBuilder + Send + Sync > ,
36+ pub ( super ) hooks : ArrowFlightEndpointHooks ,
2237}
2338
2439impl ArrowFlightEndpoint {
@@ -30,8 +45,21 @@ impl ArrowFlightEndpoint {
3045 runtime : Arc :: new ( RuntimeEnv :: default ( ) ) ,
3146 task_data_entries : Arc :: new ( ttl_map) ,
3247 session_builder : Arc :: new ( session_builder) ,
48+ hooks : ArrowFlightEndpointHooks :: default ( ) ,
3349 } )
3450 }
51+
52+ /// Adds a callback for when an [ExecutionPlan] is received in the `do_get` call.
53+ ///
54+ /// The callback takes the plan and returns another plan that must be either the same,
55+ /// or equivalent in terms of execution. Mutating the plan by adding nodes or removing them
56+ /// will make the query blow up in unexpected ways.
57+ pub fn on_plan (
58+ & mut self ,
59+ cbk : impl Fn ( Arc < dyn ExecutionPlan > ) -> Arc < dyn ExecutionPlan > + Sync + Send + ' static ,
60+ ) {
61+ self . hooks . on_plan = Arc :: new ( cbk) ;
62+ }
3563}
3664
3765#[ async_trait]
0 commit comments