Skip to content

Commit 95667be

Browse files
committed
fix execution by storing stages in a OnceCell
1 parent 8650db2 commit 95667be

File tree

15 files changed

+248
-250
lines changed

15 files changed

+248
-250
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ dashmap = "6.1.0"
2222
prost = "0.13.5"
2323
rand = "0.8.5"
2424
object_store = "0.12.3"
25+
async-stream = "0.3.6"
2526

2627
[dev-dependencies]
2728
insta = { version = "1.43.1", features = ["filters"] }

src/flight_service/do_get.rs

Lines changed: 87 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,39 @@
11
use crate::composed_extension_codec::ComposedPhysicalExtensionCodec;
22
use crate::errors::datafusion_error_to_tonic_status;
33
use crate::flight_service::service::ArrowFlightEndpoint;
4-
use crate::plan::DistributedCodec;
5-
use crate::stage::{stage_from_proto, ExecutionStageProto};
4+
use crate::plan::{DistributedCodec, PartitionGroup};
5+
use crate::stage::{stage_from_proto, ExecutionStage, ExecutionStageProto, StageKey};
66
use crate::user_provided_codec::get_user_codec;
77
use arrow_flight::encode::FlightDataEncoderBuilder;
88
use arrow_flight::error::FlightError;
99
use arrow_flight::flight_service_server::FlightService;
1010
use arrow_flight::Ticket;
11-
use datafusion::execution::SessionStateBuilder;
11+
use datafusion::execution::{SessionState, SessionStateBuilder};
1212
use datafusion::optimizer::OptimizerConfig;
13+
use datafusion::physical_plan::ExecutionPlan;
1314
use futures::TryStreamExt;
1415
use prost::Message;
1516
use std::sync::Arc;
17+
use tokio::sync::OnceCell;
1618
use tonic::{Request, Response, Status};
1719

1820
#[derive(Clone, PartialEq, ::prost::Message)]
1921
pub struct DoGet {
2022
/// The ExecutionStage that we are going to execute
2123
#[prost(message, optional, tag = "1")]
2224
pub stage_proto: Option<ExecutionStageProto>,
23-
/// the partition of the stage to execute
25+
/// The index of the task within the stage that we want to execute
2426
#[prost(uint64, tag = "2")]
27+
pub task_number: u64,
28+
/// The partition number we want to execute
29+
#[prost(uint64, tag = "3")]
2530
pub partition: u64,
31+
/// The stage key that identifies the stage. This is useful to keep
32+
/// outside of the stage proto as it is used to store the stage
33+
/// and we may not need to deserialize the entire stage proto
34+
/// if we have already stored it
35+
#[prost(message, optional, tag = "4")]
36+
pub stage_key: Option<StageKey>,
2637
}
2738

2839
impl ArrowFlightEndpoint {
@@ -35,42 +46,28 @@ impl ArrowFlightEndpoint {
3546
Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
3647
})?;
3748

38-
let stage_msg = doget
39-
.stage_proto
40-
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
49+
let partition = doget.partition as usize;
50+
let task_number = doget.task_number as usize;
51+
let (mut state, stage) = self.get_state_and_stage(doget).await?;
4152

42-
let state_builder = SessionStateBuilder::new()
43-
.with_runtime_env(Arc::clone(&self.runtime))
44-
.with_default_features();
53+
// find out which partition group we are executing
54+
let task = stage
55+
.tasks
56+
.get(task_number)
57+
.ok_or(Status::invalid_argument(format!(
58+
"Task number {} not found in stage {}",
59+
task_number,
60+
stage.name()
61+
)))?;
4562

46-
let mut state = self.session_builder.on_new_session(state_builder).build();
63+
let partition_group =
64+
PartitionGroup(task.partition_group.iter().map(|p| *p as usize).collect());
65+
state.config_mut().set_extension(Arc::new(partition_group));
4766

48-
let function_registry = state.function_registry().ok_or(Status::invalid_argument(
49-
"FunctionRegistry not present in newly built SessionState",
50-
))?;
51-
52-
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
53-
combined_codec.push(DistributedCodec);
54-
if let Some(ref user_codec) = get_user_codec(state.config()) {
55-
combined_codec.push_arc(Arc::clone(&user_codec));
56-
}
57-
58-
let mut stage = stage_from_proto(
59-
stage_msg,
60-
function_registry,
61-
&self.runtime.as_ref(),
62-
&combined_codec,
63-
)
64-
.map_err(|err| Status::invalid_argument(format!("Cannot decode stage proto: {err}")))?;
65-
let inner_plan = Arc::clone(&stage.plan);
66-
67-
// Add the extensions that might be required for ExecutionPlan nodes in the plan
68-
let config = state.config_mut();
69-
config.set_extension(Arc::clone(&self.channel_manager));
70-
config.set_extension(Arc::new(stage));
67+
let inner_plan = stage.plan.clone();
7168

7269
let stream = inner_plan
73-
.execute(doget.partition as usize, state.task_ctx())
70+
.execute(partition, state.task_ctx())
7471
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
7572

7673
let flight_data_stream = FlightDataEncoderBuilder::new()
@@ -86,4 +83,59 @@ impl ArrowFlightEndpoint {
8683
},
8784
))))
8885
}
86+
87+
async fn get_state_and_stage(
88+
&self,
89+
doget: DoGet,
90+
) -> Result<(SessionState, Arc<ExecutionStage>), Status> {
91+
let key = doget
92+
.stage_key
93+
.ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
94+
let once_stage = self.stages.entry(key).or_default();
95+
96+
let (state, stage) = once_stage
97+
.get_or_try_init(|| async {
98+
let stage_proto = doget
99+
.stage_proto
100+
.ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
101+
102+
let state_builder = SessionStateBuilder::new()
103+
.with_runtime_env(Arc::clone(&self.runtime))
104+
.with_default_features();
105+
106+
let mut state = self.session_builder.on_new_session(state_builder).build();
107+
108+
let function_registry =
109+
state.function_registry().ok_or(Status::invalid_argument(
110+
"FunctionRegistry not present in newly built SessionState",
111+
))?;
112+
113+
let mut combined_codec = ComposedPhysicalExtensionCodec::default();
114+
combined_codec.push(DistributedCodec);
115+
if let Some(ref user_codec) = get_user_codec(state.config()) {
116+
combined_codec.push_arc(Arc::clone(user_codec));
117+
}
118+
119+
let stage = stage_from_proto(
120+
stage_proto,
121+
function_registry,
122+
self.runtime.as_ref(),
123+
&combined_codec,
124+
)
125+
.map(Arc::new)
126+
.map_err(|err| {
127+
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
128+
})?;
129+
130+
// Add the extensions that might be required for ExecutionPlan nodes in the plan
131+
let config = state.config_mut();
132+
config.set_extension(Arc::clone(&self.channel_manager));
133+
config.set_extension(stage.clone());
134+
135+
Ok::<_, Status>((state, stage))
136+
})
137+
.await?;
138+
139+
Ok((state.clone(), stage.clone()))
140+
}
89141
}

src/flight_service/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
mod do_get;
22
mod service;
33
mod session_builder;
4-
mod stream_partitioner_registry;
54

65
pub(crate) use do_get::DoGet;
76

src/flight_service/service.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
use crate::channel_manager::ChannelManager;
22
use crate::flight_service::session_builder::NoopSessionBuilder;
3-
use crate::flight_service::stream_partitioner_registry::StreamPartitionerRegistry;
43
use crate::flight_service::SessionBuilder;
4+
use crate::stage::{ExecutionStage, StageKey};
55
use crate::ChannelResolver;
66
use arrow_flight::flight_service_server::FlightService;
77
use arrow_flight::{
88
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
99
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
1010
};
1111
use async_trait::async_trait;
12+
use dashmap::DashMap;
1213
use datafusion::execution::runtime_env::RuntimeEnv;
14+
use datafusion::execution::SessionState;
1315
use futures::stream::BoxStream;
1416
use std::sync::Arc;
17+
use tokio::sync::OnceCell;
1518
use tonic::{Request, Response, Status, Streaming};
1619

1720
pub struct ArrowFlightEndpoint {
1821
pub(super) channel_manager: Arc<ChannelManager>,
1922
pub(super) runtime: Arc<RuntimeEnv>,
20-
pub(super) partitioner_registry: Arc<StreamPartitionerRegistry>,
23+
pub(super) stages: DashMap<StageKey, OnceCell<(SessionState, Arc<ExecutionStage>)>>,
2124
pub(super) session_builder: Arc<dyn SessionBuilder + Send + Sync>,
2225
}
2326

@@ -26,7 +29,7 @@ impl ArrowFlightEndpoint {
2629
Self {
2730
channel_manager: Arc::new(ChannelManager::new(channel_resolver)),
2831
runtime: Arc::new(RuntimeEnv::default()),
29-
partitioner_registry: Arc::new(StreamPartitionerRegistry::default()),
32+
stages: DashMap::new(),
3033
session_builder: Arc::new(NoopSessionBuilder),
3134
}
3235
}

src/flight_service/stream_partitioner_registry.rs

Lines changed: 0 additions & 166 deletions
This file was deleted.

0 commit comments

Comments
 (0)