Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use url::Url;
use uuid::Uuid;

/// Metadata describing a stage and the input stage it reads from.
#[derive(Clone, Debug)]
pub struct StageContext {
    /// Unique identifier of this stage.
    pub id: Uuid,
    /// How many tasks are involved in the query.
    pub n_tasks: usize,
    /// Unique identifier of the stage acting as input for this one.
    pub input_id: Uuid,
    /// URLs the current stage has to read its data from.
    pub input_urls: Vec<Url>,
}

/// Per-task information attached to a task running inside a stage.
#[derive(Clone, Debug)]
pub struct StageTaskContext {
    /// Position of the current task within its stage.
    pub task_idx: usize,
}
56 changes: 39 additions & 17 deletions src/channel_manager.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use async_trait::async_trait;
use datafusion::common::internal_datafusion_err;
use datafusion::error::DataFusionError;
use datafusion::prelude::SessionConfig;
use datafusion::execution::TaskContext;
use datafusion::prelude::{SessionConfig, SessionContext};
use delegate::delegate;
use std::sync::Arc;
use tonic::body::BoxBody;
use url::Url;

/// Cloneable handle around a shared [`ChannelResolver`] trait object.
///
/// The resolver is stored behind an `Arc`, so cloning the manager shares the same
/// underlying resolver instead of duplicating it.
#[derive(Clone)]
pub struct ChannelManager(Arc<dyn ChannelResolver + Send + Sync>);

impl ChannelManager {
Expand All @@ -21,29 +23,49 @@ pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
tonic::transport::Error,
>;

/// A channel bundled together with the URL it is associated with.
#[derive(Debug, Clone)]
pub struct ArrowFlightChannel {
    /// URL associated with `channel`.
    pub url: Url,
    /// The channel handle itself.
    pub channel: BoxCloneSyncChannel,
}

/// Abstracts networking details so that users can implement their own network resolution
/// mechanism.
///
/// Every method reports failures as a `DataFusionError`.
#[async_trait]
pub trait ChannelResolver {
/// Resolves channels for a requested count `n`.
/// NOTE(review): how the channels are selected is not visible here — confirm with implementors.
async fn get_n_channels(&self, n: usize) -> Result<Vec<ArrowFlightChannel>, DataFusionError>;
// NOTE(review): `get_channel_for_url` is listed twice in this view with different return
// types (`ArrowFlightChannel` here, `BoxCloneSyncChannel` below).
async fn get_channel_for_url(&self, url: &Url) -> Result<ArrowFlightChannel, DataFusionError>;
/// Gets all available worker URLs. Used during stage assignment.
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
/// For a given URL, get a channel for communicating to it.
async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
}

impl ChannelManager {
/// Fetches the `ChannelManager` registered as an extension on `session`.
///
/// # Errors
/// Returns an internal `DataFusionError` with the message "No extension ChannelManager"
/// when the session config carries no such extension.
pub fn try_from_session(session: &SessionConfig) -> Result<Arc<Self>, DataFusionError> {
session
.get_extension::<ChannelManager>()
.ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))
}

// Forward the resolver methods to the wrapped `ChannelResolver` held in `self.0`.
delegate! {
to self.0 {
pub async fn get_n_channels(&self, n: usize) -> Result<Vec<ArrowFlightChannel>, DataFusionError>;
pub async fn get_channel_for_url(&self, url: &Url) -> Result<ArrowFlightChannel, DataFusionError>;
pub fn get_urls(&self) -> Result<Vec<Url>, DataFusionError>;
pub async fn get_channel_for_url(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError>;
}
}
}

impl TryInto<ChannelManager> for &SessionConfig {
type Error = DataFusionError;

fn try_into(self) -> Result<ChannelManager, Self::Error> {
Ok(self
.get_extension::<ChannelManager>()
.ok_or_else(|| internal_datafusion_err!("No extension ChannelManager"))?
.as_ref()
.clone())
}
}

impl TryInto<ChannelManager> for &TaskContext {
type Error = DataFusionError;

fn try_into(self) -> Result<ChannelManager, Self::Error> {
self.session_config().try_into()
}
}

impl TryInto<ChannelManager> for &SessionContext {
type Error = DataFusionError;

fn try_into(self) -> Result<ChannelManager, Self::Error> {
self.task_ctx().as_ref().try_into()
}
}
20 changes: 20 additions & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use url::Url;
use uuid::Uuid;

/// Metadata describing a stage and the input stage it reads from.
#[derive(Clone, Debug)]
pub struct StageContext {
    /// Unique identifier of this stage.
    pub id: Uuid,
    /// How many tasks are involved in the query.
    pub n_tasks: usize,
    /// Unique identifier of the stage acting as input for this one.
    pub input_id: Uuid,
    /// URLs the current stage has to read its data from.
    pub input_urls: Vec<Url>,
}

/// Per-task information attached to a task running inside a stage.
#[derive(Clone, Debug)]
pub struct StageTaskContext {
    /// Position of the current task within its stage.
    pub task_idx: usize,
}
90 changes: 45 additions & 45 deletions src/flight_service/do_get.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
use crate::context::StageTaskContext;
use crate::errors::datafusion_error_to_tonic_status;
use crate::flight_service::service::ArrowFlightEndpoint;
use crate::plan::ArrowFlightReadExecProtoCodec;
use crate::stage_delegation::{ActorContext, StageContext};
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::Ticket;
use datafusion::error::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::optimizer::OptimizerConfig;
use datafusion::physical_expr::Partitioning;
use datafusion::physical_expr::{Partitioning, PhysicalExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
use datafusion_proto::physical_plan::from_proto::parse_physical_exprs;
use datafusion_proto::physical_plan::to_proto::serialize_physical_exprs;
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
use datafusion_proto::protobuf::PhysicalPlanNode;
use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode};
use futures::TryStreamExt;
use prost::Message;
use std::sync::Arc;
use tonic::{Request, Response, Status};
use uuid::Uuid;

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DoGet {
Expand All @@ -35,26 +37,38 @@ pub enum DoGetInner {
/// Protobuf message carrying a serialized physical plan together with the routing
/// information the receiving endpoint needs to execute it.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct RemotePlanExec {
#[prost(message, optional, boxed, tag = "1")]
plan: Option<Box<PhysicalPlanNode>>,
#[prost(message, optional, tag = "2")]
stage_context: Option<StageContext>,
#[prost(message, optional, tag = "3")]
actor_context: Option<ActorContext>,
/// Serialized physical plan; the endpoint rejects the request when it is absent.
pub plan: Option<Box<PhysicalPlanNode>>,
/// Stage identifier in string form; the endpoint parses it back into a `Uuid`.
#[prost(string, tag = "2")]
pub stage_id: String,
/// Index of the task that should run the plan.
#[prost(uint64, tag = "3")]
pub task_idx: u64,
/// Index of the requesting task on the output side.
#[prost(uint64, tag = "4")]
pub output_task_idx: u64,
/// Number of output tasks; used by the endpoint as the partition count.
#[prost(uint64, tag = "5")]
pub output_tasks: u64,
/// Serialized physical expressions from which the endpoint derives the partitioning.
#[prost(message, repeated, tag = "6")]
pub hash_expr: Vec<PhysicalExprNode>,
}

impl DoGet {
pub fn new_remote_plan_exec_ticket(
plan: Arc<dyn ExecutionPlan>,
stage_context: StageContext,
actor_context: ActorContext,
stage_id: Uuid,
task_idx: usize,
output_task_idx: usize,
output_tasks: usize,
hash_expr: &[Arc<dyn PhysicalExpr>],
extension_codec: &dyn PhysicalExtensionCodec,
) -> Result<Ticket, DataFusionError> {
let node = PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?;
let do_get = Self {
inner: Some(DoGetInner::RemotePlanExec(RemotePlanExec {
plan: Some(Box::new(node)),
stage_context: Some(stage_context),
actor_context: Some(actor_context),
stage_id: stage_id.to_string(),
task_idx: task_idx as u64,
output_task_idx: output_task_idx as u64,
output_tasks: output_tasks as u64,
hash_expr: serialize_physical_exprs(hash_expr, extension_codec)?,
})),
};
Ok(Ticket::new(do_get.encode_to_vec()))
Expand Down Expand Up @@ -91,14 +105,6 @@ impl ArrowFlightEndpoint {
return invalid_argument("RemotePlanExec is missing the plan");
};

let Some(stage_context) = action.stage_context else {
return invalid_argument("RemotePlanExec is missing the stage context");
};

let Some(actor_context) = action.actor_context else {
return invalid_argument("RemotePlanExec is missing the actor context");
};

let mut codec = ComposedPhysicalExtensionCodec::default();
codec.push(ArrowFlightReadExecProtoCodec);
codec.push_from_config(state.config());
Expand All @@ -107,40 +113,34 @@ impl ArrowFlightEndpoint {
.try_into_physical_plan(function_registry, &self.runtime, &codec)
.map_err(|err| Status::internal(format!("Cannot deserialize plan: {err}")))?;

let stage_id = stage_context.id.clone();
let caller_actor_idx = actor_context.caller_actor_idx as usize;
let actor_idx = actor_context.actor_idx as usize;
let prev_n = stage_context.prev_actors as usize;
let partitioning = match parse_protobuf_partitioning(
stage_context.partitioning.as_ref(),
let stage_id = Uuid::parse_str(&action.stage_id).map_err(|err| {
Status::invalid_argument(format!(
"Cannot parse stage id '{}': {err}",
action.stage_id
))
})?;

let task_idx = action.task_idx as usize;
let caller_actor_idx = action.output_task_idx as usize;
let prev_n = action.output_tasks as usize;
let partitioning = match parse_physical_exprs(
&action.hash_expr,
function_registry,
&plan.schema(),
&codec,
) {
// We need to replace the partition count in the provided Partitioning scheme with
// the number of actors in the previous stage. ArrowFlightReadExec might be declaring
// N partitions, but each ArrowFlightReadExec::execute(n) call will go to a different
// actor in the next stage.
//
// Each actor in that next stage (us here) needs to expose as many partitioned streams
// as actors exist on its previous stage.
Ok(Some(partitioning)) => match partitioning {
Partitioning::RoundRobinBatch(_) => Partitioning::RoundRobinBatch(prev_n),
Partitioning::Hash(expr, _) => Partitioning::Hash(expr, prev_n),
Partitioning::UnknownPartitioning(_) => Partitioning::UnknownPartitioning(prev_n),
},
Ok(None) => return invalid_argument("Missing partitioning"),
Err(err) => return invalid_argument(format!("Cannot parse partitioning {err}")),
// Hash partitioning needs at least one expression to hash on: with an empty list fall
// back to round-robin, otherwise hash on the expressions that were sent.
Ok(expr) if expr.is_empty() => Partitioning::RoundRobinBatch(prev_n),
Ok(expr) => Partitioning::Hash(expr, prev_n),
Err(err) => return invalid_argument(format!("Cannot parse hash expressions {err}")),
};

let config = state.config_mut();
config.set_extension(Arc::clone(&self.stage_delegation));
config.set_extension(Arc::clone(&self.channel_manager));
config.set_extension(Arc::new(stage_context));
config.set_extension(Arc::new(actor_context));
config.set_extension(Arc::new(StageTaskContext { task_idx }));

let stream_partitioner = self
.partitioner_registry
.get_or_create_stream_partitioner(stage_id, actor_idx, plan, partitioning)
.get_or_create_stream_partitioner(stage_id, task_idx, plan, partitioning)
.map_err(|err| datafusion_error_to_tonic_status(&err))?;

let stream = stream_partitioner
Expand Down
81 changes: 0 additions & 81 deletions src/flight_service/do_put.rs

This file was deleted.

2 changes: 0 additions & 2 deletions src/flight_service/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
mod do_get;
mod do_put;
mod service;
mod session_builder;
mod stream_partitioner_registry;

pub(crate) use do_get::DoGet;
pub(crate) use do_put::DoPut;

pub use service::ArrowFlightEndpoint;
pub use session_builder::SessionBuilder;
Loading