Commit 700b1de

Refactor do_get.rs and adjacent files
1 parent 355da32 commit 700b1de

File tree

5 files changed: +92 -100 lines changed

src/channel_resolver_ext.rs

Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,5 @@
 use async_trait::async_trait;
+use datafusion::common::exec_datafusion_err;
 use datafusion::error::DataFusionError;
 use datafusion::prelude::SessionConfig;
 use std::sync::Arc;
@@ -16,9 +17,10 @@ pub(crate) fn set_distributed_channel_resolver(
 
 pub(crate) fn get_distributed_channel_resolver(
     cfg: &SessionConfig,
-) -> Option<Arc<dyn ChannelResolver + Send + Sync>> {
+) -> Result<Arc<dyn ChannelResolver + Send + Sync>, DataFusionError> {
     cfg.get_extension::<ChannelResolverExtension>()
         .map(|cm| cm.0.clone())
+        .ok_or_else(|| exec_datafusion_err!("ChannelResolver not present in the session config"))
 }
 
 #[derive(Clone)]
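
Switching the return type from `Option` to `Result` lets call sites propagate a missing resolver with `?` and get a descriptive error for free. A minimal sketch of the new call-site pattern, assuming the crate's `get_distributed_channel_resolver` is in scope (the wrapper function itself is hypothetical):

```rust
use datafusion::error::DataFusionError;
use datafusion::prelude::SessionConfig;

// Hypothetical caller: the missing-resolver case now surfaces as a
// DataFusionError with context and bubbles up via `?`, instead of
// requiring a `let Some(..) else { return exec_err!(..) }` block.
fn resolve_worker_urls(cfg: &SessionConfig) -> Result<(), DataFusionError> {
    let resolver = get_distributed_channel_resolver(cfg)?;
    let _urls = resolver.get_urls()?; // `get_urls` is also fallible, as used in stage.rs below
    Ok(())
}
```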

src/execution_plans/arrow_flight_read.rs

Lines changed: 5 additions & 9 deletions
@@ -147,18 +147,14 @@ impl ExecutionPlan for ArrowFlightReadExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream, DataFusionError> {
-        let ArrowFlightReadExec::Ready(this) = self else {
+        let ArrowFlightReadExec::Ready(self_ready) = self else {
             return exec_err!("ArrowFlightReadExec is not ready, was the distributed optimization step performed?");
         };
 
         // get the channel manager and current stage from our context
-        let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config())
-        else {
-            return exec_err!(
-                "ArrowFlightReadExec requires a ChannelResolver in the session config"
-            );
-        };
+        let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
 
+        // the `ArrowFlightReadExec` node can only be executed in the context of a `StageExec`
         let stage = context
             .session_config()
             .get_extension::<StageExec>()
@@ -170,10 +166,10 @@ impl ExecutionPlan for ArrowFlightReadExec {
         // reading from
         let child_stage = stage
             .child_stages_iter()
-            .find(|s| s.num == this.stage_num)
+            .find(|s| s.num == self_ready.stage_num)
             .ok_or(internal_datafusion_err!(
                 "ArrowFlightReadExec: no child stage with num {}",
-                this.stage_num
+                self_ready.stage_num
            ))?;
 
         let flight_metadata = context
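
The rename from `this` to `self_ready` makes the binding introduced by the refutable `let .. else` pattern explicit. A standalone sketch of that let-else-on-an-enum-variant idiom, with simplified stand-in types rather than the crate's real definitions:

```rust
// Stand-ins for ArrowFlightReadExec's two states; not the real types.
enum Node {
    Pending,
    Ready(ReadyNode),
}

struct ReadyNode {
    stage_num: usize,
}

fn stage_num(node: &Node) -> Result<usize, String> {
    // If the pattern does not match, the else branch must diverge
    // (return, panic, etc.), so the happy path stays un-nested.
    let Node::Ready(self_ready) = node else {
        return Err("node is not ready".to_string());
    };
    Ok(self_ready.stage_num)
}
```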

src/execution_plans/stage.rs

Lines changed: 4 additions & 14 deletions
@@ -1,6 +1,6 @@
 use crate::channel_resolver_ext::get_distributed_channel_resolver;
 use crate::{ArrowFlightReadExec, ChannelResolver, PartitionIsolatorExec};
-use datafusion::common::{exec_err, internal_err};
+use datafusion::common::internal_err;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::TaskContext;
 use datafusion::physical_plan::{
@@ -260,20 +260,10 @@ impl ExecutionPlan for StageExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<datafusion::execution::SendableRecordBatchStream> {
-        let stage = self
-            .as_any()
-            .downcast_ref::<StageExec>()
-            .expect("Unwrapping myself should always work");
-
-        let Some(channel_resolver) = get_distributed_channel_resolver(context.session_config())
-        else {
-            return exec_err!("ChannelManager not found in session config");
-        };
-
-        let urls = channel_resolver.get_urls()?;
+        let channel_resolver = get_distributed_channel_resolver(context.session_config())?;
 
-        let assigned_stage = stage
-            .try_assign_urls(&urls)
+        let assigned_stage = self
+            .try_assign_urls(&channel_resolver.get_urls()?)
             .map(Arc::new)
             .map_err(|e| DataFusionError::Execution(e.to_string()))?;
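
The deleted lines downcast `self` to its own concrete type, which can never fail inside a method on `StageExec` itself, so the refactor just uses `self` directly. A minimal sketch of why that round-trip was a no-op (trait and types are simplified stand-ins):

```rust
use std::any::Any;

trait Plan {
    fn as_any(&self) -> &dyn Any;
}

struct StageExec;

impl Plan for StageExec {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl StageExec {
    fn execute(&self) {
        // Inside a method on StageExec, `self` already is a &StageExec,
        // so this downcast always succeeds and adds nothing:
        let _stage = self.as_any().downcast_ref::<StageExec>().unwrap();
    }
}
```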

src/flight_service/do_get.rs

Lines changed: 78 additions & 73 deletions
@@ -9,13 +9,14 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
 use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
-use datafusion::execution::SessionState;
+use datafusion::execution::{SendableRecordBatchStream, SessionState};
 use futures::TryStreamExt;
+use http::HeaderMap;
 use prost::Message;
+use std::fmt::Display;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use tokio::sync::OnceCell;
-use tonic::metadata::MetadataMap;
 use tonic::{Request, Response, Status};
 
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -41,9 +42,9 @@ pub struct DoGet {
 /// TaskData stores state for a single task being executed by this Endpoint. It may be shared
 /// by concurrent requests for the same task which execute separate partitions.
 pub struct TaskData {
-    pub(super) state: SessionState,
+    pub(super) session_state: SessionState,
     pub(super) stage: Arc<StageExec>,
-    ///num_partitions_remaining is initialized to the total number of partitions in the task (not
+    /// `num_partitions_remaining` is initialized to the total number of partitions in the task (not
     /// only tasks in the partition group). This is decremented for each request to the endpoint
     /// for this task. Once this count is zero, the task is likely complete. The task may not be
     /// complete because it's possible that the same partition was retried and this count was
@@ -56,98 +57,78 @@ impl ArrowFlightEndpoint {
         &self,
         request: Request<Ticket>,
     ) -> Result<Response<<ArrowFlightEndpoint as FlightService>::DoGetStream>, Status> {
-        let (metadata, _ext, ticket) = request.into_parts();
-        let Ticket { ticket } = ticket;
-        let doget = DoGet::decode(ticket).map_err(|err| {
+        let (metadata, _ext, body) = request.into_parts();
+        let doget = DoGet::decode(body.ticket).map_err(|err| {
             Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
         })?;
 
+        // There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
+        // so here we either retrieve the existing one or create a new one if it does not exist.
+        let (mut session_state, stage) = self
+            .get_state_and_stage(
+                doget.stage_key.ok_or_else(missing("stage_key"))?,
+                doget.stage_proto.ok_or_else(missing("stage_proto"))?,
+                metadata.clone().into_headers(),
+            )
+            .await?;
+
+        // Find out which partition group we are executing
         let partition = doget.partition as usize;
         let task_number = doget.task_number as usize;
-        let task_data = self.get_state_and_stage(doget, metadata).await?;
-
-        let stage = task_data.stage;
-        let mut state = task_data.state;
-
-        // find out which partition group we are executing
-        let task = stage
-            .tasks
-            .get(task_number)
-            .ok_or(Status::invalid_argument(format!(
-                "Task number {} not found in stage {}",
-                task_number,
-                stage.name()
-            )))?;
-
-        let partition_group = PartitionGroup(task.partition_group.clone());
-        state.config_mut().set_extension(Arc::new(partition_group));
-
-        let inner_plan = stage.plan.clone();
-
-        let stream = inner_plan
-            .execute(partition, state.task_ctx())
+        let task = stage.tasks.get(task_number).ok_or_else(invalid(format!(
+            "Task number {task_number} not found in stage {}",
+            stage.num
+        )))?;
+
+        let cfg = session_state.config_mut();
+        cfg.set_extension(Arc::new(PartitionGroup(task.partition_group.clone())));
+        cfg.set_extension(Arc::clone(&stage));
+        cfg.set_extension(Arc::new(ContextGrpcMetadata(metadata.into_headers())));
+
+        // Rather than executing the `StageExec` itself, we want to execute the inner plan instead,
+        // as executing `StageExec` performs some worker assignation that should have already been
+        // done in the head stage.
+        let stream = stage
+            .plan
+            .execute(partition, session_state.task_ctx())
             .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
-        let flight_data_stream = FlightDataEncoderBuilder::new()
-            .with_schema(inner_plan.schema().clone())
-            .build(stream.map_err(|err| {
-                FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
-            }));
-
-        Ok(Response::new(Box::pin(flight_data_stream.map_err(
-            |err| match err {
-                FlightError::Tonic(status) => *status,
-                _ => Status::internal(format!("Error during flight stream: {err}")),
-            },
-        ))))
+        Ok(record_batch_stream_to_response(stream))
     }
 
     async fn get_state_and_stage(
         &self,
-        doget: DoGet,
-        metadata_map: MetadataMap,
-    ) -> Result<TaskData, Status> {
-        let key = doget
-            .stage_key
-            .ok_or(Status::invalid_argument("DoGet is missing the stage key"))?;
-        let once_stage = self
-            .stages
+        key: StageKey,
+        stage_proto: StageExecProto,
+        headers: HeaderMap,
+    ) -> Result<(SessionState, Arc<StageExec>), Status> {
+        let once = self
+            .task_data_entries
             .get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
 
-        let stage_data = once_stage
+        let stage_data = once
            .get_or_try_init(|| async {
-                let stage_proto = doget
-                    .stage_proto
-                    .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?;
-
-                let headers = metadata_map.into_headers();
-                let mut state = self
+                let session_state = self
                    .session_builder
                    .build_session_state(DistributedSessionBuilderContext {
                        runtime_env: Arc::clone(&self.runtime),
-                        headers: headers.clone(),
+                        headers,
                    })
                    .await
                    .map_err(|err| datafusion_error_to_tonic_status(&err))?;
 
-                let codec = DistributedCodec::new_combined_with_user(state.config());
+                let codec = DistributedCodec::new_combined_with_user(session_state.config());
 
-                let stage = stage_from_proto(stage_proto, &state, self.runtime.as_ref(), &codec)
-                    .map(Arc::new)
+                let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec)
                    .map_err(|err| {
                        Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
                    })?;
 
-                // Add the extensions that might be required for ExecutionPlan nodes in the plan
-                let config = state.config_mut();
-                config.set_extension(stage.clone());
-                config.set_extension(Arc::new(ContextGrpcMetadata(headers)));
-
                // Initialize partition count to the number of partitions in the stage
                let total_partitions = stage.plan.properties().partitioning.partition_count();
                Ok::<_, Status>(TaskData {
-                    state,
-                    stage,
+                    session_state,
+                    stage: Arc::new(stage),
                    num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
                })
            })
@@ -158,13 +139,37 @@ impl ArrowFlightEndpoint {
            .num_partitions_remaining
            .fetch_sub(1, Ordering::SeqCst);
        if remaining_partitions <= 1 {
-            self.stages.remove(key.clone());
+            self.task_data_entries.remove(key);
        }
 
-        Ok(stage_data.clone())
+        Ok((stage_data.session_state.clone(), stage_data.stage.clone()))
    }
 }
 
+fn missing(field: &'static str) -> impl FnOnce() -> Status {
+    move || Status::invalid_argument(format!("Missing field '{field}'"))
+}
+
+fn invalid(msg: impl Display) -> impl FnOnce() -> Status {
+    move || Status::invalid_argument(msg.to_string())
+}
+
+fn record_batch_stream_to_response(
+    stream: SendableRecordBatchStream,
+) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
+    let flight_data_stream =
+        FlightDataEncoderBuilder::new()
+            .with_schema(stream.schema().clone())
+            .build(stream.map_err(|err| {
+                FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
+            }));
+
+    Response::new(Box::pin(flight_data_stream.map_err(|err| match err {
+        FlightError::Tonic(status) => *status,
+        _ => Status::internal(format!("Error during flight stream: {err}")),
+    })))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -262,28 +267,28 @@ mod tests {
         }
 
         // Check that the endpoint has not evicted any task states.
-        assert_eq!(endpoint.stages.len(), num_tasks);
+        assert_eq!(endpoint.task_data_entries.len(), num_tasks);
 
         // Run the last partition of task 0. Any partition number works. Verify that the task state
         // is evicted because all partitions have been processed.
         let result = do_get(1, 0, task_keys[0].clone()).await;
         assert!(result.is_ok());
-        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 2);
         assert!(stored_stage_keys.contains(&task_keys[1]));
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of task 1.
         let result = do_get(1, 1, task_keys[1].clone()).await;
         assert!(result.is_ok());
-        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 1);
         assert!(stored_stage_keys.contains(&task_keys[2]));
 
         // Run the last partition of the last task.
         let result = do_get(1, 2, task_keys[2].clone()).await;
         assert!(result.is_ok());
-        let stored_stage_keys = endpoint.stages.keys().collect::<Vec<StageKey>>();
+        let stored_stage_keys = endpoint.task_data_entries.keys().collect::<Vec<StageKey>>();
         assert_eq!(stored_stage_keys.len(), 0);
     }
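
The new `missing` and `invalid` helpers return closures rather than ready-made `Status` values so they slot into `Option::ok_or_else`, which expects `FnOnce() -> E` and only builds the error (and its `format!` allocation) on the failure path. A self-contained sketch of that pattern:

```rust
use tonic::Status;

// Same shape as the helper added in this commit: a factory that captures
// the field name and defers constructing the Status until it is called.
fn missing(field: &'static str) -> impl FnOnce() -> Status {
    move || Status::invalid_argument(format!("Missing field '{field}'"))
}

fn extract(stage_key: Option<String>) -> Result<String, Status> {
    // `ok_or_else` invokes the closure only when the Option is None,
    // so the happy path pays no formatting cost.
    stage_key.ok_or_else(missing("stage_key"))
}
```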

src/flight_service/service.rs

Lines changed: 2 additions & 3 deletions
@@ -30,8 +30,7 @@ pub struct StageKey {
 
 pub struct ArrowFlightEndpoint {
     pub(super) runtime: Arc<RuntimeEnv>,
-    #[allow(clippy::type_complexity)]
-    pub(super) stages: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
+    pub(super) task_data_entries: TTLMap<StageKey, Arc<OnceCell<TaskData>>>,
     pub(super) session_builder: Arc<dyn DistributedSessionBuilder + Send + Sync>,
 }
 
@@ -42,7 +41,7 @@ impl ArrowFlightEndpoint {
         let ttl_map = TTLMap::try_new(TTLMapConfig::default())?;
         Ok(Self {
             runtime: Arc::new(RuntimeEnv::default()),
-            stages: ttl_map,
+            task_data_entries: ttl_map,
             session_builder: Arc::new(session_builder),
         })
     }
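
The renamed `task_data_entries` field maps each `StageKey` to an `Arc<OnceCell<TaskData>>`: the map entry itself is cheap to create for whichever request arrives first, while the expensive `TaskData` initialization runs at most once per key even under concurrent requests for the same task. A minimal sketch of that two-layer pattern using `tokio::sync::OnceCell`, with a plain `Mutex<HashMap>` standing in for `TTLMap` and a `String` standing in for `TaskData`:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::OnceCell;

#[derive(Default)]
struct Registry {
    // Stand-in for TTLMap<StageKey, Arc<OnceCell<TaskData>>>.
    entries: Mutex<HashMap<String, Arc<OnceCell<String>>>>,
}

impl Registry {
    async fn get_or_build(&self, key: &str) -> String {
        // Layer 1: cheap, synchronous entry creation under the map lock.
        let cell = self
            .entries
            .lock()
            .unwrap()
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(OnceCell::new()))
            .clone();
        // Layer 2: the async initializer runs at most once per key;
        // concurrent callers for the same key await the same in-flight init.
        cell.get_or_init(|| async { format!("task data for {key}") })
            .await
            .clone()
    }
}
```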
