Skip to content

Commit d3b46ab

Browse files
authored
Fix early drop stateful nodes (#159)
* Add issue reproducer * Hold a reference to the stage for as long as the stream is executed * Add comment for the test
1 parent 3337ce8 commit d3b46ab

File tree

4 files changed

+383
-1
lines changed

4 files changed

+383
-1
lines changed

src/common/callback_stream.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
use futures::Stream;
2+
use pin_project::{pin_project, pinned_drop};
3+
use std::fmt::Display;
4+
use std::pin::Pin;
5+
use std::task::{Context, Poll};
6+
7+
/// The reason why the stream ended:
8+
/// - [CallbackStreamEndReason::Finished] if it finished gracefully
9+
/// - [CallbackStreamEndReason::Aborted] if it was abandoned.
10+
#[derive(Debug)]
11+
pub enum CallbackStreamEndReason {
12+
/// The stream finished gracefully.
13+
Finished,
14+
/// The stream was abandoned.
15+
Aborted,
16+
}
17+
18+
impl Display for CallbackStreamEndReason {
19+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20+
write!(f, "{:?}", self)
21+
}
22+
}
23+
24+
/// Stream that executes a callback when it is fully consumed or gets cancelled.
25+
#[pin_project(PinnedDrop)]
26+
pub struct CallbackStream<S, F>
27+
where
28+
S: Stream,
29+
F: FnOnce(CallbackStreamEndReason),
30+
{
31+
#[pin]
32+
stream: S,
33+
callback: Option<F>,
34+
}
35+
36+
impl<S, F> Stream for CallbackStream<S, F>
37+
where
38+
S: Stream,
39+
F: FnOnce(CallbackStreamEndReason),
40+
{
41+
type Item = S::Item;
42+
43+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44+
let this = self.project();
45+
46+
match this.stream.poll_next(cx) {
47+
Poll::Ready(None) => {
48+
// Stream is fully consumed, execute the callback
49+
if let Some(callback) = this.callback.take() {
50+
callback(CallbackStreamEndReason::Finished);
51+
}
52+
Poll::Ready(None)
53+
}
54+
other => other,
55+
}
56+
}
57+
}
58+
59+
#[pinned_drop]
60+
impl<S, F> PinnedDrop for CallbackStream<S, F>
61+
where
62+
S: Stream,
63+
F: FnOnce(CallbackStreamEndReason),
64+
{
65+
fn drop(self: Pin<&mut Self>) {
66+
let this = self.project();
67+
if let Some(callback) = this.callback.take() {
68+
callback(CallbackStreamEndReason::Aborted);
69+
}
70+
}
71+
}
72+
73+
/// Wrap a stream with a callback that will be executed when the stream is fully
74+
/// consumed or gets canceled.
75+
pub fn with_callback<S, F>(stream: S, callback: F) -> CallbackStream<S, F>
76+
where
77+
S: Stream,
78+
F: FnOnce(CallbackStreamEndReason) + Send + 'static,
79+
{
80+
CallbackStream {
81+
stream,
82+
callback: Some(callback),
83+
}
84+
}

src/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
mod callback_stream;
12
mod composed_extension_codec;
23
mod partitioning;
34
#[allow(unused)]
45
pub mod ttl_map;
56

7+
pub(crate) use callback_stream::with_callback;
68
pub(crate) use composed_extension_codec::ComposedPhysicalExtensionCodec;
79
pub(crate) use partitioning::{scale_partitioning, scale_partitioning_props};

src/flight_service/do_get.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::common::with_callback;
12
use crate::config_extension_ext::ContextGrpcMetadata;
23
use crate::execution_plans::{DistributedTaskContext, StageExec};
34
use crate::flight_service::service::ArrowFlightEndpoint;
@@ -11,6 +12,7 @@ use arrow_flight::error::FlightError;
1112
use arrow_flight::flight_service_server::FlightService;
1213
use datafusion::common::exec_datafusion_err;
1314
use datafusion::execution::SendableRecordBatchStream;
15+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
1416
use futures::TryStreamExt;
1517
use prost::Message;
1618
use std::sync::Arc;
@@ -126,7 +128,17 @@ impl ArrowFlightEndpoint {
126128
.execute(doget.target_partition as usize, session_state.task_ctx())
127129
.map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
128130

129-
Ok(record_batch_stream_to_response(stream))
131+
let schema = stream.schema();
132+
let stream = with_callback(stream, move |_| {
133+
// We need to hold a reference to the plan for at least as long as the stream is
134+
// execution. Some plans might store state necessary for the stream to work, and
135+
// dropping the plan early could drop this state too soon.
136+
let _ = stage.plan;
137+
});
138+
139+
Ok(record_batch_stream_to_response(Box::pin(
140+
RecordBatchStreamAdapter::new(schema, stream),
141+
)))
130142
}
131143
}
132144

0 commit comments

Comments
 (0)