Skip to content

Commit 2df3467

Browse files
flight_service: add TrailingFlightDataStream (#157)
This new type of streams allows you to "chain" streams together. This will be used to "chain" FlightData containing metrics to the stream corresponding to the last executed partition in a task. Streams already offer a `chain()` API but it lacks a feature we need - a callback to decide whether or not to collect and append metrics. When a stream is complete, as a part of the proposed metrics protocol, the ArrowFlightEndpoint needs to check if this is the last stream/partition. This synchronization point between N streams for N partitions will have to be managed in some shared state between the stream and the ArrowFlightEndpoint. In this callback, we can check if this is the last partition. If so, we can collect metrics and send them in the trailing stream. For more details, see the draft implementation [here](https://github.com/datafusion-contrib/datafusion-distributed/pull/139/files#diff-fa3e517ceea7f93b2d50873bcdf7f48f6110a5cf8b25a4a8df338f7d71dc6fdb) One alternative option considered was to have the callback generate `FlightData` which can be sent on the stream. Ultimately, it seemed cleaner to just return another stream. Informs: #123
1 parent 64920af commit 2df3467

File tree

2 files changed

+237
-1
lines changed

2 files changed

+237
-1
lines changed

src/flight_service/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mod do_get;
22
mod service;
33
mod session_builder;
4-
4+
mod trailing_flight_data_stream;
55
pub(crate) use do_get::DoGet;
66

77
pub use service::ArrowFlightEndpoint;
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
use arrow_flight::{FlightData, error::FlightError};
2+
use futures::stream::Stream;
3+
use pin_project::pin_project;
4+
use std::pin::Pin;
5+
use std::task::{Context, Poll};
6+
use tokio::pin;
7+
8+
/// TrailingFlightDataStream - wraps a FlightData stream. It calls the `on_complete` closure when the stream is finished.
9+
/// If the closure returns a new stream, it will be appended to the original stream and consumed.
10+
#[pin_project]
11+
pub struct TrailingFlightDataStream<S, F>
12+
where
13+
S: Stream<Item = Result<FlightData, FlightError>> + Send,
14+
F: FnOnce() -> Result<Option<S>, FlightError>,
15+
{
16+
#[pin]
17+
inner: S,
18+
on_complete: Option<F>,
19+
#[pin]
20+
trailing_stream: Option<S>,
21+
}
22+
23+
impl<S, F> TrailingFlightDataStream<S, F>
24+
where
25+
S: Stream<Item = Result<FlightData, FlightError>> + Send,
26+
F: FnOnce() -> Result<Option<S>, FlightError>,
27+
{
28+
// TODO: remove
29+
#[allow(dead_code)]
30+
pub fn new(on_complete: F, inner: S) -> Self {
31+
Self {
32+
inner,
33+
on_complete: Some(on_complete),
34+
trailing_stream: None,
35+
}
36+
}
37+
}
38+
39+
impl<S, F> Stream for TrailingFlightDataStream<S, F>
40+
where
41+
S: Stream<Item = Result<FlightData, FlightError>> + Send,
42+
F: FnOnce() -> Result<Option<S>, FlightError>,
43+
{
44+
type Item = Result<FlightData, FlightError>;
45+
46+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
47+
let mut this = self.as_mut().project();
48+
49+
match this.inner.poll_next(cx) {
50+
Poll::Ready(Some(Ok(flight_data))) => Poll::Ready(Some(Ok(flight_data))),
51+
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
52+
Poll::Ready(None) => {
53+
if let Some(trailing_stream) = this.trailing_stream.as_mut().as_pin_mut() {
54+
return trailing_stream.poll_next(cx);
55+
}
56+
if let Some(on_complete) = this.on_complete.take() {
57+
if let Some(trailing_stream) = on_complete()? {
58+
this.trailing_stream.set(Some(trailing_stream));
59+
return self.poll_next(cx);
60+
}
61+
}
62+
Poll::Ready(None)
63+
}
64+
Poll::Pending => Poll::Pending,
65+
}
66+
}
67+
}
68+
69+
#[cfg(test)]
70+
mod tests {
71+
use super::*;
72+
use arrow::array::{Array, Int32Array, StringArray};
73+
use arrow::datatypes::{DataType, Field, Schema};
74+
use arrow::record_batch::RecordBatch;
75+
use arrow_flight::FlightData;
76+
use arrow_flight::decode::FlightRecordBatchStream;
77+
use arrow_flight::encode::FlightDataEncoderBuilder;
78+
use futures::stream::{self, StreamExt};
79+
use std::sync::Arc;
80+
81+
fn create_trailing_flight_data_stream(
82+
name_array: StringArray,
83+
value_array: Int32Array,
84+
) -> Pin<Box<dyn Stream<Item = Result<FlightData, FlightError>> + Send>> {
85+
create_flight_data_stream_inner(name_array, value_array, true)
86+
}
87+
88+
fn create_flight_data_stream(
89+
name_array: StringArray,
90+
value_array: Int32Array,
91+
) -> Pin<Box<dyn Stream<Item = Result<FlightData, FlightError>> + Send>> {
92+
create_flight_data_stream_inner(name_array, value_array, false)
93+
}
94+
95+
// Creates a stream of RecordBatches.
96+
fn create_flight_data_stream_inner(
97+
name_array: StringArray,
98+
value_array: Int32Array,
99+
is_trailing: bool,
100+
) -> Pin<Box<dyn Stream<Item = Result<FlightData, FlightError>> + Send>> {
101+
assert_eq!(
102+
name_array.len(),
103+
value_array.len(),
104+
"StringArray and Int32Array must have equal lengths"
105+
);
106+
107+
let schema = Arc::new(Schema::new(vec![
108+
Field::new("name", DataType::Utf8, false),
109+
Field::new("value", DataType::Int32, false),
110+
]));
111+
112+
let batches: Vec<RecordBatch> = (0..name_array.len())
113+
.map(|i| {
114+
let name_slice = name_array.slice(i, 1);
115+
let value_slice = value_array.slice(i, 1);
116+
117+
RecordBatch::try_new(
118+
schema.clone(),
119+
vec![Arc::new(name_slice), Arc::new(value_slice)],
120+
)
121+
.unwrap()
122+
})
123+
.collect();
124+
125+
let batch_stream = futures::stream::iter(batches.into_iter().map(Ok));
126+
let flight_stream = FlightDataEncoderBuilder::new()
127+
.with_schema(schema)
128+
.build(batch_stream);
129+
130+
// By default, this encoder will emit a schema message as the first message in the stream.
131+
// Since we are concatenating streams, we need to drop the schema message from the trailing stream.
132+
if is_trailing {
133+
// Skip the schema message
134+
return Box::pin(flight_stream.skip(1));
135+
}
136+
Box::pin(flight_stream)
137+
}
138+
139+
#[tokio::test]
140+
async fn test_basic_streaming_functionality() {
141+
let name_array = StringArray::from(vec!["a", "b", "c"]);
142+
let value_array = Int32Array::from(vec![1, 2, 3]);
143+
let inner_stream = create_flight_data_stream(name_array, value_array);
144+
145+
let name_array = StringArray::from(vec!["d", "e", "f"]);
146+
let value_array = Int32Array::from(vec![5, 6, 7]);
147+
let trailing_stream = create_trailing_flight_data_stream(name_array, value_array);
148+
149+
let on_complete = || Ok(Some(trailing_stream));
150+
151+
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
152+
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
153+
.collect::<Vec<Result<RecordBatch, FlightError>>>()
154+
.await;
155+
156+
assert_eq!(record_batches.len(), 6);
157+
assert!(record_batches.iter().all(|batch| batch.is_ok()));
158+
assert_eq!(
159+
record_batches
160+
.iter()
161+
.map(|batch| batch
162+
.as_ref()
163+
.unwrap()
164+
.column(0)
165+
.as_any()
166+
.downcast_ref::<StringArray>()
167+
.unwrap()
168+
.value(0))
169+
.collect::<Vec<_>>(),
170+
vec!["a", "b", "c", "d", "e", "f"]
171+
);
172+
}
173+
174+
#[tokio::test]
175+
async fn test_error_handling_in_inner_stream() {
176+
let mut stream =
177+
create_flight_data_stream(StringArray::from(vec!["item1"]), Int32Array::from(vec![1]));
178+
let schema_message = stream.next().await.unwrap().unwrap();
179+
let flight_data = stream.next().await.unwrap().unwrap();
180+
let data = vec![
181+
Ok(schema_message),
182+
Ok(flight_data),
183+
Err(FlightError::ExternalError(Box::new(std::io::Error::new(
184+
std::io::ErrorKind::Other,
185+
"test error",
186+
)))),
187+
];
188+
let inner_stream = stream::iter(data);
189+
let on_complete = || Ok(None);
190+
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
191+
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
192+
.collect::<Vec<Result<RecordBatch, FlightError>>>()
193+
.await;
194+
195+
assert_eq!(record_batches.len(), 2);
196+
assert!(record_batches[0].is_ok());
197+
assert!(record_batches[1].is_err());
198+
}
199+
200+
#[tokio::test]
201+
async fn test_error_handling_in_on_complete_callback() {
202+
let name_array = StringArray::from(vec!["item1"]);
203+
let value_array = Int32Array::from(vec![1]);
204+
let inner_stream = create_flight_data_stream(name_array, value_array);
205+
206+
let on_complete = || -> Result<Option<_>, FlightError> {
207+
Err(FlightError::ExternalError(Box::new(std::io::Error::new(
208+
std::io::ErrorKind::Other,
209+
"callback error",
210+
))))
211+
};
212+
213+
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
214+
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
215+
.collect::<Vec<Result<RecordBatch, FlightError>>>()
216+
.await;
217+
assert_eq!(record_batches.len(), 2);
218+
assert!(record_batches[0].is_ok());
219+
assert!(record_batches[1].is_err());
220+
}
221+
222+
#[tokio::test]
223+
async fn test_stream_with_no_trailer() {
224+
let inner_stream = create_flight_data_stream(
225+
StringArray::from(vec!["item1"] as Vec<&str>),
226+
Int32Array::from(vec![1] as Vec<i32>),
227+
);
228+
let on_complete = || Ok(None);
229+
let trailing_stream = TrailingFlightDataStream::new(on_complete, inner_stream);
230+
let record_batches = FlightRecordBatchStream::new_from_flight_data(trailing_stream)
231+
.collect::<Vec<Result<RecordBatch, FlightError>>>()
232+
.await;
233+
assert_eq!(record_batches.len(), 1);
234+
assert!(record_batches[0].is_ok());
235+
}
236+
}

0 commit comments

Comments
 (0)