Merged
17 changes: 15 additions & 2 deletions src/flight_service/do_get.rs
@@ -94,7 +94,10 @@ impl ArrowFlightEndpoint {
let stage_data = once
.get_or_try_init(|| async {
let proto_node = PhysicalPlanNode::try_decode(doget.plan_proto.as_ref())?;
let plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
let mut plan = proto_node.try_into_physical_plan(&ctx, &self.runtime, &codec)?;
for hook in self.hooks.on_plan.iter() {
    plan = hook(plan);
}

// Initialize partition count to the number of partitions in the stage
let total_partitions = plan.properties().partitioning.partition_count();
@@ -226,8 +229,16 @@ mod tests {
#[tokio::test]
async fn test_task_data_partition_counting() {
// Create ArrowFlightEndpoint with DefaultSessionBuilder
let endpoint =
let mut endpoint =
ArrowFlightEndpoint::try_new(DefaultSessionBuilder).expect("Failed to create endpoint");
let plans_received = Arc::new(AtomicUsize::default());
{
let plans_received = Arc::clone(&plans_received);
endpoint.add_on_plan_hook(move |plan| {
plans_received.fetch_add(1, Ordering::SeqCst);
plan
});
}

// Create 3 tasks with 3 partitions each.
let num_tasks = 3;
@@ -297,6 +308,8 @@ mod tests {
assert!(result.is_ok());
}
}
// As many plans as tasks should have been received.
assert_eq!(plans_received.load(Ordering::SeqCst), task_keys.len());
Collaborator:
Not a blocker, but an independent test where you swap out a plan node (e.g. replace DataSourceExec with EmptyExec) would be nice.

Collaborator Author:
🤔 but that would fail pretty badly; part of the contract with this hook is that people should not do that. Or do you expect that to pass?

Collaborator (@jayshrivastava, Oct 22, 2025):
I think swapping out a leaf node like DataSourceExec will work because it preserves the plan structure, as long as the schema matches.

Collaborator Author:
But that's not something we are willing to support, right? The contract is more like: "don't touch anything that changes the plan"; we just don't enforce it at runtime for performance reasons.


// Check that the endpoint has not evicted any task states.
assert_eq!(endpoint.task_data_entries.len(), num_tasks as usize);
22 changes: 22 additions & 0 deletions src/flight_service/service.rs
@@ -10,15 +10,24 @@ use arrow_flight::{
use async_trait::async_trait;
use datafusion::error::DataFusionError;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::ExecutionPlan;
use futures::stream::BoxStream;
use std::sync::Arc;
use tokio::sync::OnceCell;
use tonic::{Request, Response, Status, Streaming};

#[allow(clippy::type_complexity)]
#[derive(Default)]
pub(super) struct ArrowFlightEndpointHooks {
pub(super) on_plan:
Vec<Arc<dyn Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send>>,
}

pub struct ArrowFlightEndpoint {
pub(super) runtime: Arc<RuntimeEnv>,
pub(super) task_data_entries: Arc<TTLMap<StageKey, Arc<OnceCell<TaskData>>>>,
pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
pub(super) hooks: ArrowFlightEndpointHooks,
}

impl ArrowFlightEndpoint {
@@ -30,8 +39,21 @@ impl ArrowFlightEndpoint {
runtime: Arc::new(RuntimeEnv::default()),
task_data_entries: Arc::new(ttl_map),
session_builder: Arc::new(session_builder),
hooks: ArrowFlightEndpointHooks::default(),
})
}

/// Adds a callback for when an [ExecutionPlan] is received in the `do_get` call.
///
/// The callback takes the plan and returns another plan, which must be either the same
/// plan or one equivalent in terms of execution. Mutating the plan by adding or removing
/// nodes will make the query blow up in unexpected ways.
Collaborator:
You can add an assertion that the final schema matches within the ArrowFlightEndpoint so we fail early with "schema mismatch in hook" to make it clear that a hook caused this sort of error.

The harder part is asserting that the plan structure is the same, which is important for metrics. Traversing the plan to assert this would be expensive. Since we only support wrapping (because the plan structure cannot change), can we leverage with_new_children in the hook itself?

ex.

trait Hook {
    fn apply(&self, node: &Arc<dyn ExecutionPlan>) -> bool;
    fn new(&self, node: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan>;
}


// in the endpoint

fn apply_hook<H: Hook>(hook: H, plan: Arc<dyn ExecutionPlan>) {
    plan.transform_up(|node| {
        if hook.apply(&node) {
            let children = node.children().into_iter().cloned().collect();
            Transformed::yes(hook.new(node).with_new_children(children))
        } else {
            Transformed::no(node)
        }
    })
}

This makes hooks a bit more expensive because we necessarily traverse the whole plan, where previously we were not. But for datafusion-tracing, we would have traversed the whole plan anyways.

This might be more effort than it's worth so I'll leave the decision up to you. If you do implement something like this, then I'm happy to take another look :)

Collaborator Author:
🤔 I imagine in most cases there's just going to be no hook; I think it's worth not penalizing those cases.

The "schema mismatch" might be a good idea though

Collaborator:
Sure that makes sense

Collaborator Author:
I think for now I'm going to keep it simple and not do anything. In case we find it necessary to add some checks we can refer back to this conversation.

pub fn add_on_plan_hook(
&mut self,
hook: impl Fn(Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> + Sync + Send + 'static,
) {
self.hooks.on_plan.push(Arc::new(hook));
}
}

#[async_trait]