Skip to content

Commit abaefd9

Browse files
committed
Refactor do_get.rs
1 parent 1cb3e65 commit abaefd9

File tree

1 file changed

+39
-59
lines changed

1 file changed

+39
-59
lines changed

src/flight_service/do_get.rs

Lines changed: 39 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,12 @@ use arrow_flight::encode::FlightDataEncoderBuilder;
99
use arrow_flight::error::FlightError;
1010
use arrow_flight::flight_service_server::FlightService;
1111
use arrow_flight::Ticket;
12-
use datafusion::execution::{SendableRecordBatchStream, SessionState};
12+
use datafusion::execution::SendableRecordBatchStream;
1313
use futures::TryStreamExt;
14-
use http::HeaderMap;
1514
use prost::Message;
1615
use std::fmt::Display;
1716
use std::sync::atomic::{AtomicUsize, Ordering};
1817
use std::sync::Arc;
19-
use tokio::sync::OnceCell;
2018
use tonic::{Request, Response, Status};
2119

2220
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -42,7 +40,6 @@ pub struct DoGet {
4240
/// TaskData stores state for a single task being executed by this Endpoint. It may be shared
4341
/// by concurrent requests for the same task which execute separate partitions.
4442
pub struct TaskData {
45-
pub(super) session_state: SessionState,
4643
pub(super) stage: Arc<StageExec>,
4744
/// `num_partitions_remaining` is initialized to the total number of partitions in the task (not
4845
/// only tasks in the partition group). This is decremented for each request to the endpoint
@@ -62,15 +59,47 @@ impl ArrowFlightEndpoint {
6259
Status::invalid_argument(format!("Cannot decode DoGet message: {err}"))
6360
})?;
6461

62+
let mut session_state = self
63+
.session_builder
64+
.build_session_state(DistributedSessionBuilderContext {
65+
runtime_env: Arc::clone(&self.runtime),
66+
headers: metadata.clone().into_headers(),
67+
})
68+
.await
69+
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
70+
71+
let codec = DistributedCodec::new_combined_with_user(session_state.config());
72+
6573
// There's only 1 `StageExec` responsible for all requests that share the same `stage_key`,
6674
// so here we either retrieve the existing one or create a new one if it does not exist.
67-
let (mut session_state, stage) = self
68-
.get_state_and_stage(
69-
doget.stage_key.ok_or_else(missing("stage_key"))?,
70-
doget.stage_proto.ok_or_else(missing("stage_proto"))?,
71-
metadata.clone().into_headers(),
72-
)
75+
let key = doget.stage_key.ok_or_else(missing("stage_key"))?;
76+
let once = self
77+
.task_data_entries
78+
.get_or_init(key.clone(), Default::default);
79+
80+
let stage_data = once
81+
.get_or_try_init(|| async {
82+
let stage_proto = doget.stage_proto.ok_or_else(missing("stage_proto"))?;
83+
let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec)
84+
.map_err(|err| {
85+
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
86+
})?;
87+
88+
// Initialize partition count to the number of partitions in the stage
89+
let total_partitions = stage.plan.properties().partitioning.partition_count();
90+
Ok::<_, Status>(TaskData {
91+
stage: Arc::new(stage),
92+
num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
93+
})
94+
})
7395
.await?;
96+
let stage = Arc::clone(&stage_data.stage);
97+
let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
98+
99+
// If all the partitions are done, remove the stage from the cache.
100+
if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) <= 1 {
101+
self.task_data_entries.remove(key);
102+
}
74103

75104
// Find out which partition group we are executing
76105
let partition = doget.partition as usize;
@@ -95,55 +124,6 @@ impl ArrowFlightEndpoint {
95124

96125
Ok(record_batch_stream_to_response(stream))
97126
}
98-
99-
async fn get_state_and_stage(
100-
&self,
101-
key: StageKey,
102-
stage_proto: StageExecProto,
103-
headers: HeaderMap,
104-
) -> Result<(SessionState, Arc<StageExec>), Status> {
105-
let once = self
106-
.task_data_entries
107-
.get_or_init(key.clone(), || Arc::new(OnceCell::<TaskData>::new()));
108-
109-
let stage_data = once
110-
.get_or_try_init(|| async {
111-
let session_state = self
112-
.session_builder
113-
.build_session_state(DistributedSessionBuilderContext {
114-
runtime_env: Arc::clone(&self.runtime),
115-
headers,
116-
})
117-
.await
118-
.map_err(|err| datafusion_error_to_tonic_status(&err))?;
119-
120-
let codec = DistributedCodec::new_combined_with_user(session_state.config());
121-
122-
let stage = stage_from_proto(stage_proto, &session_state, &self.runtime, &codec)
123-
.map_err(|err| {
124-
Status::invalid_argument(format!("Cannot decode stage proto: {err}"))
125-
})?;
126-
127-
// Initialize partition count to the number of partitions in the stage
128-
let total_partitions = stage.plan.properties().partitioning.partition_count();
129-
Ok::<_, Status>(TaskData {
130-
session_state,
131-
stage: Arc::new(stage),
132-
num_partitions_remaining: Arc::new(AtomicUsize::new(total_partitions)),
133-
})
134-
})
135-
.await?;
136-
137-
// If all the partitions are done, remove the stage from the cache.
138-
let remaining_partitions = stage_data
139-
.num_partitions_remaining
140-
.fetch_sub(1, Ordering::SeqCst);
141-
if remaining_partitions <= 1 {
142-
self.task_data_entries.remove(key);
143-
}
144-
145-
Ok((stage_data.session_state.clone(), stage_data.stage.clone()))
146-
}
147127
}
148128

149129
fn missing(field: &'static str) -> impl FnOnce() -> Status {

0 commit comments

Comments (0)