Skip to content

Commit a25907e

Browse files
committed
Remove stage delegation in favor of planning-time stage assignment
1 parent ebde2bf commit a25907e

22 files changed

+383
-720
lines changed

context.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use url::Url;
2+
use uuid::Uuid;
3+
4+
#[derive(Debug, Clone)]
5+
pub struct StageContext {
6+
/// Unique identifier of the Stage.
7+
pub id: Uuid,
8+
/// Number of tasks involved in the query.
9+
pub n_tasks: usize,
10+
/// Unique identifier of the input Stage.
11+
pub input_id: Uuid,
12+
/// Urls from which the current stage will need to read data.
13+
pub input_urls: Vec<Url>,
14+
}
15+
16+
/// Per-task metadata for a single task running within a Stage.
#[derive(Clone, Debug)]
pub struct StageTaskContext {
    /// Index of the current task in a stage.
    pub task_idx: usize,
}

src/channel_manager.rs

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use async_trait::async_trait;
22
use datafusion::common::internal_datafusion_err;
33
use datafusion::error::DataFusionError;
4-
use datafusion::prelude::SessionConfig;
4+
use datafusion::execution::TaskContext;
5+
use datafusion::prelude::{SessionConfig, SessionContext};
56
use delegate::delegate;
67
use std::sync::Arc;
78
use tonic::body::BoxBody;
89
use url::Url;
910

11+
#[derive(Clone)]
1012
pub struct ChannelManager(Arc<dyn ChannelResolver + Send + Sync>);
1113

1214
impl ChannelManager {
@@ -21,29 +23,49 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
2123
tonic::transport::Error,
2224
>;
2325

24-
#[derive(Clone, Debug)]
25-
pub struct ArrowFlightChannel {
26-
pub url: Url,
27-
pub channel: BoxCloneSyncChannel,
28-
}
29-
26+
/// Abstracts networking details so that users can implement their own network resolution
27+
/// mechanism.
3028
#[async_trait]
3129
pub trait ChannelResolver {
32-
async fn get_n_channels(&self, n: usize) -> Result<Vec<ArrowFlightChannel>, DataFusionError>;
33-
async fn get_channel_for_url(&self, url: &Url) -> Result<ArrowFlightChannel, DataFusionError>;
30+
/// Gets all available worker URLs. Used during stage assignment.
31+
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
32+
/// For a given URL, get a channel for communicating to it.
33+
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
3434
}
3535

3636
impl ChannelManager {
37-
pub fn try_from_session(session: &SessionConfig) -> Result<Arc<Self>, DataFusionError> {
38-
session
39-
.get_extension::<ChannelManager>()
40-
.ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))
41-
}
42-
4337
delegate! {
4438
to self.0 {
45-
pub async fn get_n_channels(&self, n: usize) -> Result<Vec<ArrowFlightChannel>, DataFusionError>;
46-
pub async fn get_channel_for_url(&self, url: &Url) -> Result<ArrowFlightChannel, DataFusionError>;
39+
pub fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
40+
pub async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
4741
}
4842
}
4943
}
44+
45+
impl TryInto<ChannelManager> for &SessionConfig {
46+
type Error = DataFusionError;
47+
48+
fn try_into(self) -> Result<ChannelManager, Self::Error> {
49+
Ok(self
50+
.get_extension::<ChannelManager>()
51+
.ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))?
52+
.as_ref()
53+
.clone())
54+
}
55+
}
56+
57+
impl TryInto<ChannelManager> for &TaskContext {
58+
type Error = DataFusionError;
59+
60+
fn try_into(self) -> Result<ChannelManager, Self::Error> {
61+
self.session_config().try_into()
62+
}
63+
}
64+
65+
impl TryInto<ChannelManager> for &SessionContext {
66+
type Error = DataFusionError;
67+
68+
fn try_into(self) -> Result<ChannelManager, Self::Error> {
69+
self.task_ctx().as_ref().try_into()
70+
}
71+
}

src/context.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
use url::Url;
2+
use uuid::Uuid;
3+
4+
#[derive(Debug, Clone)]
5+
pub struct StageContext {
6+
/// Unique identifier of the Stage.
7+
pub id: Uuid,
8+
/// Number of tasks involved in the query.
9+
pub n_tasks: usize,
10+
/// Unique identifier of the input Stage.
11+
pub input_id: Uuid,
12+
/// Urls from which the current stage will need to read data.
13+
pub input_urls: Vec<Url>,
14+
}
15+
16+
/// Per-task metadata for a single task running within a Stage.
#[derive(Clone, Debug)]
pub struct StageTaskContext {
    /// Index of the current task in a stage.
    pub task_idx: usize,
}

src/flight_service/do_get.rs

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
2+
use crate::context::StageTaskContext;
23
use crate::errors::datafusion_error_to_tonic_status;
34
use crate::flight_service::service::ArrowFlightEndpoint;
45
use crate::plan::ArrowFlightReadExecProtoCodec;
5-
use crate::stage_delegation::{ActorContext, StageContext};
66
use arrow_flight::encode::FlightDataEncoderBuilder;
77
use arrow_flight::error::FlightError;
88
use arrow_flight::flight_service_server::FlightService;
99
use arrow_flight::Ticket;
1010
use datafusion::error::DataFusionError;
1111
use datafusion::execution::SessionStateBuilder;
1212
use datafusion::optimizer::OptimizerConfig;
13-
use datafusion::physical_expr::Partitioning;
13+
use datafusion::physical_expr::{Partitioning, PhysicalExpr};
1414
use datafusion::physical_plan::ExecutionPlan;
15-
use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
15+
use datafusion_proto::physical_plan::from_proto::parse_physical_exprs;
16+
use datafusion_proto::physical_plan::to_proto::serialize_physical_exprs;
1617
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
17-
use datafusion_proto::protobuf::PhysicalPlanNode;
18+
use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode};
1819
use futures::TryStreamExt;
1920
use prost::Message;
2021
use std::sync::Arc;
2122
use tonic::{Request, Response, Status};
23+
use uuid::Uuid;
2224

2325
#[derive(Clone, PartialEq, ::prost::Message)]
2426
pub struct DoGet {
@@ -35,26 +37,38 @@ pub enum DoGetInner {
3537
/// Protobuf payload carried inside a `DoGet` ticket describing a plan to execute remotely.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RemotePlanExec {
    /// Serialized physical plan to execute on the remote worker.
    #[prost(message, optional, boxed, tag = "1")]
    pub plan: Option<Box<PhysicalPlanNode>>,
    /// Stage identifier, encoded as a stringified UUID (produced by
    /// `new_remote_plan_exec_ticket`, parsed back with `Uuid::parse_str` on the server).
    #[prost(string, tag = "2")]
    pub stage_id: String,
    /// Index of the task to run within the stage.
    #[prost(uint64, tag = "3")]
    pub task_idx: u64,
    /// Index of the calling task in the output stage — NOTE(review): bound to
    /// `caller_actor_idx` on the server side; confirm intended semantics.
    #[prost(uint64, tag = "4")]
    pub output_task_idx: u64,
    /// Total number of tasks in the output stage (used as the partition count
    /// of the partitioning scheme rebuilt on the server).
    #[prost(uint64, tag = "5")]
    pub output_tasks: u64,
    /// Serialized physical expressions for hash partitioning, if any.
    #[prost(message, repeated, tag = "6")]
    pub hash_expr: Vec<PhysicalExprNode>,
}
4452

4553
impl DoGet {
4654
pub fn new_remote_plan_exec_ticket(
4755
plan: Arc<dyn ExecutionPlan>,
48-
stage_context: StageContext,
49-
actor_context: ActorContext,
56+
stage_id: Uuid,
57+
task_idx: usize,
58+
output_task_idx: usize,
59+
output_tasks: usize,
60+
hash_expr: &[Arc<dyn PhysicalExpr>],
5061
extension_codec: &dyn PhysicalExtensionCodec,
5162
) -> Result<Ticket, DataFusionError> {
5263
let node = PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?;
5364
let do_get = Self {
5465
inner: Some(DoGetInner::RemotePlanExec(RemotePlanExec {
5566
plan: Some(Box::new(node)),
56-
stage_context: Some(stage_context),
57-
actor_context: Some(actor_context),
67+
stage_id: stage_id.to_string(),
68+
task_idx: task_idx as u64,
69+
output_task_idx: output_task_idx as u64,
70+
output_tasks: output_tasks as u64,
71+
hash_expr: serialize_physical_exprs(hash_expr, extension_codec)?,
5872
})),
5973
};
6074
Ok(Ticket::new(do_get.encode_to_vec()))
@@ -91,14 +105,6 @@ impl ArrowFlightEndpoint {
91105
return invalid_argument("RemotePlanExec is missing the plan");
92106
};
93107

94-
let Some(stage_context) = action.stage_context else {
95-
return invalid_argument("RemotePlanExec is missing the stage context");
96-
};
97-
98-
let Some(actor_context) = action.actor_context else {
99-
return invalid_argument("RemotePlanExec is missing the actor context");
100-
};
101-
102108
let mut codec = ComposedPhysicalExtensionCodec::default();
103109
codec.push(ArrowFlightReadExecProtoCodec);
104110
codec.push_from_config(state.config());
@@ -107,40 +113,34 @@ impl ArrowFlightEndpoint {
107113
.try_into_physical_plan(function_registry, &self.runtime, &codec)
108114
.map_err(|err| Status::internal(format!("Cannot deserialize plan: {err}")))?;
109115

110-
let stage_id = stage_context.id.clone();
111-
let caller_actor_idx = actor_context.caller_actor_idx as usize;
112-
let actor_idx = actor_context.actor_idx as usize;
113-
let prev_n = stage_context.prev_actors as usize;
114-
let partitioning = match parse_protobuf_partitioning(
115-
stage_context.partitioning.as_ref(),
116+
let stage_id = Uuid::parse_str(&action.stage_id).map_err(|err| {
117+
Status::invalid_argument(format!(
118+
"Cannot parse stage id '{}': {err}",
119+
action.stage_id
120+
))
121+
})?;
122+
123+
let task_idx = action.task_idx as usize;
124+
let caller_actor_idx = action.output_task_idx as usize;
125+
let prev_n = action.output_tasks as usize;
126+
let partitioning = match parse_physical_exprs(
127+
&action.hash_expr,
116128
function_registry,
117129
&plan.schema(),
118130
&codec,
119131
) {
120-
// We need to replace the partition count in the provided Partitioning scheme with
121-
// the number of actors in the previous stage. ArrowFlightReadExec might be declaring
122-
// N partitions, but each ArrowFlightReadExec::execute(n) call will go to a different
123-
// actor in the next stage.
124-
//
125-
// Each actor in that next stage (us here) needs to expose as many partitioned streams
126-
// as actors exist on its previous stage.
127-
Ok(Some(partitioning)) => match partitioning {
128-
Partitioning::RoundRobinBatch(_) => Partitioning::RoundRobinBatch(prev_n),
129-
Partitioning::Hash(expr, _) => Partitioning::Hash(expr, prev_n),
130-
Partitioning::UnknownPartitioning(_) => Partitioning::UnknownPartitioning(prev_n),
131-
},
132-
Ok(None) => return invalid_argument("Missing partitioning"),
133-
Err(err) => return invalid_argument(format!("Cannot parse partitioning {err}")),
132+
Ok(expr) if expr.is_empty() => Partitioning::Hash(expr, prev_n),
133+
Ok(_) => Partitioning::RoundRobinBatch(prev_n),
134+
Err(err) => return invalid_argument(format!("Cannot parse hash expressions {err}")),
134135
};
136+
135137
let config = state.config_mut();
136-
config.set_extension(Arc::clone(&self.stage_delegation));
137138
config.set_extension(Arc::clone(&self.channel_manager));
138-
config.set_extension(Arc::new(stage_context));
139-
config.set_extension(Arc::new(actor_context));
139+
config.set_extension(Arc::new(StageTaskContext { task_idx }));
140140

141141
let stream_partitioner = self
142142
.partitioner_registry
143-
.get_or_create_stream_partitioner(stage_id, actor_idx, plan, partitioning)
143+
.get_or_create_stream_partitioner(stage_id, task_idx, plan, partitioning)
144144
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
145145

146146
let stream = stream_partitioner

src/flight_service/do_put.rs

Lines changed: 0 additions & 81 deletions
This file was deleted.

src/flight_service/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
mod do_get;
2-
mod do_put;
32
mod service;
43
mod session_builder;
54
mod stream_partitioner_registry;
65

76
pub(crate) use do_get::DoGet;
8-
pub(crate) use do_put::DoPut;
97

108
pub use service::ArrowFlightEndpoint;
119
pub use session_builder::SessionBuilder;

0 commit comments

Comments
 (0)