Skip to content

Commit 5e670b2

Browse files
committed
Use map_last_stream.rs in favor of callback_stream.rs
1 parent 624228b commit 5e670b2

File tree

4 files changed

+106
-192
lines changed

4 files changed

+106
-192
lines changed

src/common/callback_stream.rs

Lines changed: 0 additions & 84 deletions
This file was deleted.

src/common/map_last_stream.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use futures::{Stream, StreamExt, stream};
2+
use std::task::Poll;
3+
4+
/// Maps the last element of the provided stream.
5+
pub(crate) fn map_last_stream<T>(
6+
mut input: impl Stream<Item = T> + Unpin,
7+
map_f: impl FnOnce(T) -> T,
8+
) -> impl Stream<Item = T> + Unpin {
9+
let mut final_closure = Some(map_f);
10+
11+
// this is used to peek the new value so that we can map upon emitting the last message
12+
let mut current_value = None;
13+
14+
stream::poll_fn(move |cx| match futures::ready!(input.poll_next_unpin(cx)) {
15+
Some(new_val) => {
16+
match current_value.take() {
17+
// This is the first value, so we store it and repoll to get the next value
18+
None => {
19+
current_value = Some(new_val);
20+
cx.waker().wake_by_ref();
21+
Poll::Pending
22+
}
23+
24+
Some(existing) => {
25+
current_value = Some(new_val);
26+
27+
Poll::Ready(Some(existing))
28+
}
29+
}
30+
}
31+
// this is our last value, so we map it using the user provided closure
32+
None => match current_value.take() {
33+
Some(existing) => {
34+
// make sure we wake ourselves to finish the stream
35+
cx.waker().wake_by_ref();
36+
37+
if let Some(closure) = final_closure.take() {
38+
Poll::Ready(Some(closure(existing)))
39+
} else {
40+
unreachable!("the closure is only executed once")
41+
}
42+
}
43+
None => Poll::Ready(None),
44+
},
45+
})
46+
}
47+
48+
#[cfg(test)]
49+
mod tests {
50+
use super::*;
51+
use futures::stream;
52+
53+
#[tokio::test]
54+
async fn test_map_last_stream_empty_stream() {
55+
let input = stream::empty::<i32>();
56+
let mapped = map_last_stream(input, |x| x + 10);
57+
let result: Vec<i32> = mapped.collect().await;
58+
assert_eq!(result, Vec::<i32>::new());
59+
}
60+
61+
#[tokio::test]
62+
async fn test_map_last_stream_single_element() {
63+
let input = stream::iter(vec![5]);
64+
let mapped = map_last_stream(input, |x| x * 2);
65+
let result: Vec<i32> = mapped.collect().await;
66+
assert_eq!(result, vec![10]);
67+
}
68+
69+
#[tokio::test]
70+
async fn test_map_last_stream_multiple_elements() {
71+
let input = stream::iter(vec![1, 2, 3, 4]);
72+
let mapped = map_last_stream(input, |x| x + 100);
73+
let result: Vec<i32> = mapped.collect().await;
74+
assert_eq!(result, vec![1, 2, 3, 104]); // Only the last element is transformed
75+
}
76+
77+
#[tokio::test]
78+
async fn test_map_last_stream_preserves_order() {
79+
let input = stream::iter(vec![10, 20, 30, 40, 50]);
80+
let mapped = map_last_stream(input, |x| x - 50);
81+
let result: Vec<i32> = mapped.collect().await;
82+
assert_eq!(result, vec![10, 20, 30, 40, 0]); // Last element: 50 - 50 = 0
83+
}
84+
}

src/common/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
mod callback_stream;
21
mod composed_extension_codec;
2+
mod map_last_stream;
33
mod partitioning;
44
#[allow(unused)]
55
pub mod ttl_map;
66

7-
pub(crate) use callback_stream::with_callback;
87
pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec;
8+
pub(crate) use map_last_stream::map_last_stream;
99
pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};

src/flight_service/do_get.rs

Lines changed: 20 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::common::with_callback;
1+
use crate::common::map_last_stream;
22
use crate::config_extension_ext::ContextGrpcMetadata;
33
use crate::execution_plans::{DistributedTaskContext, StageExec};
44
use crate::flight_service::service::ArrowFlightEndpoint;
@@ -17,15 +17,11 @@ use arrow_flight::flight_service_server::FlightService;
1717
use bytes::Bytes;
1818

1919
use datafusion::common::exec_datafusion_err;
20-
use datafusion::execution::SendableRecordBatchStream;
21-
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
2220
use datafusion::prelude::SessionContext;
23-
use futures::stream;
24-
use futures::{StreamExt, TryStreamExt};
21+
use futures::TryStreamExt;
2522
use prost::Message;
2623
use std::sync::Arc;
2724
use std::sync::atomic::{AtomicUsize, Ordering};
28-
use std::task::Poll;
2925
use tonic::{Request, Response, Status};
3026

3127
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -132,115 +128,33 @@ impl ArrowFlightEndpoint {
132128
.execute(doget.target_partition as usize, session_state.task_ctx())
133129
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
134130

135-
let schema = stream.schema();
136-
137-
// TODO: We don't need to do this since the stage / plan is captured again by the
138-
// TrailingFlightDataStream. However, we will eventuall only use the TrailingFlightDataStream
139-
// if we are running an `explain (analyze)` command. We should update this section
140-
// to only use one or the other - not both.
141-
let plan_capture = stage.plan.clone();
142-
let stream = with_callback(stream, move |_| {
143-
// We need to hold a reference to the plan for at least as long as the stream is
144-
// execution. Some plans might store state necessary for the stream to work, and
145-
// dropping the plan early could drop this state too soon.
146-
let _ = plan_capture;
131+
let stream = FlightDataEncoderBuilder::new()
132+
.with_schema(stream.schema().clone())
133+
.build(stream.map_err(|err| {
134+
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
135+
}));
136+
137+
let task_data_entries = Arc::clone(&self.task_data_entries);
138+
let num_partitions_remaining = Arc::clone(&stage_data.num_partitions_remaining);
139+
140+
let stream = map_last_stream(stream, move |last| {
141+
if num_partitions_remaining.fetch_sub(1, Ordering::SeqCst) == 1 {
142+
task_data_entries.remove(key.clone());
143+
}
144+
last.and_then(|el| collect_and_create_metrics_flight_data(key, stage, el))
147145
});
148146

149-
let record_batch_stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
150-
let task_data_capture = self.task_data_entries.clone();
151-
Ok(flight_stream_from_record_batch_stream(
152-
key.clone(),
153-
stage_data.clone(),
154-
move || {
155-
task_data_capture.remove(key.clone());
156-
},
157-
record_batch_stream,
158-
))
147+
Ok(Response::new(Box::pin(stream.map_err(|err| match err {
148+
FlightError::Tonic(status) => *status,
149+
_ => Status::internal(format!("Error during flight stream: {err}")),
150+
}))))
159151
}
160152
}
161153

162154
fn missing(field: &'static str) -> impl FnOnce() -> Status {
163155
move || Status::invalid_argument(format!("Missing field '{field}'"))
164156
}
165157

166-
/// Creates a tonic response from a stream of record batches. Handles
167-
/// - RecordBatch to flight conversion partition tracking, stage eviction, and trailing metrics.
168-
fn flight_stream_from_record_batch_stream(
169-
stage_key: StageKey,
170-
stage_data: TaskData,
171-
evict_stage: impl FnOnce() + Send + 'static,
172-
stream: SendableRecordBatchStream,
173-
) -> Response<<ArrowFlightEndpoint as FlightService>::DoGetStream> {
174-
let mut flight_data_stream =
175-
FlightDataEncoderBuilder::new()
176-
.with_schema(stream.schema().clone())
177-
.build(stream.map_err(|err| {
178-
FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err)))
179-
}));
180-
181-
// executed once when the stream ends
182-
// decorates the last flight data with our metrics
183-
let mut final_closure = Some(move |last_flight_data| {
184-
if stage_data
185-
.num_partitions_remaining
186-
.fetch_sub(1, Ordering::SeqCst)
187-
== 1
188-
{
189-
evict_stage();
190-
191-
collect_and_create_metrics_flight_data(stage_key, stage_data.stage, last_flight_data)
192-
} else {
193-
Ok(last_flight_data)
194-
}
195-
});
196-
197-
// this is used to peek the new value
198-
// so that we can add our metrics to the last flight data
199-
let mut current_value = None;
200-
201-
let stream =
202-
stream::poll_fn(
203-
move |cx| match futures::ready!(flight_data_stream.poll_next_unpin(cx)) {
204-
Some(Ok(new_val)) => {
205-
match current_value.take() {
206-
// This is the first value, so we store it and repoll to get the next value
207-
None => {
208-
current_value = Some(new_val);
209-
cx.waker().wake_by_ref();
210-
Poll::Pending
211-
}
212-
213-
Some(existing) => {
214-
current_value = Some(new_val);
215-
216-
Poll::Ready(Some(Ok(existing)))
217-
}
218-
}
219-
}
220-
// this is our last value, so we add our metrics to this flight data
221-
None => match current_value.take() {
222-
Some(existing) => {
223-
// make sure we wake ourselves to finish the stream
224-
cx.waker().wake_by_ref();
225-
226-
if let Some(closure) = final_closure.take() {
227-
Poll::Ready(Some(closure(existing)))
228-
} else {
229-
unreachable!("the closure is only executed once")
230-
}
231-
}
232-
None => Poll::Ready(None),
233-
},
234-
err => Poll::Ready(err),
235-
},
236-
);
237-
238-
Response::new(Box::pin(stream.map_err(|err| match err {
239-
FlightError::Tonic(status) => *status,
240-
_ => Status::internal(format!("Error during flight stream: {err}")),
241-
})))
242-
}
243-
244158
/// Collects metrics from the provided stage and includes it in the flight data
245159
fn collect_and_create_metrics_flight_data(
246160
stage_key: StageKey,

0 commit comments

Comments
 (0)