Skip to content

Commit 4a6482d

Browse files
committed
Add integration test for error propagation
1 parent 35ba673 commit 4a6482d

File tree

1 file changed

+172
-0
lines changed

1 file changed

+172
-0
lines changed

tests/error_propagation.rs

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#[allow(dead_code)]
2+
mod common;
3+
4+
#[cfg(test)]
5+
mod tests {
6+
use crate::common::localhost::start_localhost_context;
7+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
8+
use datafusion::error::DataFusionError;
9+
use datafusion::execution::{
10+
FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext,
11+
};
12+
use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
13+
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
14+
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
15+
use datafusion::physical_plan::{
16+
execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
17+
};
18+
use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder};
19+
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
20+
use datafusion_proto::protobuf::proto_error;
21+
use futures::{stream, TryStreamExt};
22+
use prost::Message;
23+
use std::any::Any;
24+
use std::error::Error;
25+
use std::fmt::Formatter;
26+
use std::sync::Arc;
27+
28+
#[tokio::test]
29+
async fn test_error_propagation() -> Result<(), Box<dyn Error>> {
30+
#[derive(Clone)]
31+
struct CustomSessionBuilder;
32+
impl SessionBuilder for CustomSessionBuilder {
33+
fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
34+
let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(ErrorExecCodec);
35+
let config = builder.config().get_or_insert_default();
36+
config.set_extension(Arc::new(codec));
37+
builder
38+
}
39+
}
40+
let (ctx, _guard) =
41+
start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await;
42+
43+
let codec: Arc<dyn PhysicalExtensionCodec> = Arc::new(ErrorExecCodec);
44+
ctx.state_ref()
45+
.write()
46+
.config_mut()
47+
.set_extension(Arc::new(codec));
48+
49+
let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));
50+
51+
for size in [1, 2, 3] {
52+
plan = Arc::new(ArrowFlightReadExec::new(
53+
plan,
54+
Partitioning::RoundRobinBatch(size),
55+
));
56+
}
57+
58+
let stream = execute_stream(plan, ctx.task_ctx())?;
59+
60+
let Err(err) = stream.try_collect::<Vec<_>>().await else {
61+
panic!("Should have failed")
62+
};
63+
assert_eq!(
64+
DataFusionError::Execution("something failed".to_string()).to_string(),
65+
err.to_string()
66+
);
67+
68+
Ok(())
69+
}
70+
71+
#[derive(Debug)]
72+
pub struct ErrorExec {
73+
msg: String,
74+
plan_properties: PlanProperties,
75+
}
76+
77+
impl ErrorExec {
78+
fn new(msg: &str) -> Self {
79+
let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]);
80+
Self {
81+
msg: msg.to_string(),
82+
plan_properties: PlanProperties::new(
83+
EquivalenceProperties::new(Arc::new(schema)),
84+
Partitioning::UnknownPartitioning(1),
85+
EmissionType::Incremental,
86+
Boundedness::Bounded,
87+
),
88+
}
89+
}
90+
}
91+
92+
impl DisplayAs for ErrorExec {
93+
fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
94+
write!(f, "ErrorExec")
95+
}
96+
}
97+
98+
impl ExecutionPlan for ErrorExec {
99+
fn name(&self) -> &str {
100+
"ErrorExec"
101+
}
102+
103+
fn as_any(&self) -> &dyn Any {
104+
self
105+
}
106+
107+
fn properties(&self) -> &PlanProperties {
108+
&self.plan_properties
109+
}
110+
111+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
112+
vec![]
113+
}
114+
115+
fn with_new_children(
116+
self: Arc<Self>,
117+
_: Vec<Arc<dyn ExecutionPlan>>,
118+
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
119+
Ok(self)
120+
}
121+
122+
fn execute(
123+
&self,
124+
_: usize,
125+
_: Arc<TaskContext>,
126+
) -> datafusion::common::Result<SendableRecordBatchStream> {
127+
Ok(Box::pin(RecordBatchStreamAdapter::new(
128+
self.schema(),
129+
stream::iter(vec![Err(DataFusionError::Execution(self.msg.clone()))]),
130+
)))
131+
}
132+
}
133+
134+
#[derive(Debug)]
135+
struct ErrorExecCodec;
136+
137+
#[derive(Clone, PartialEq, ::prost::Message)]
138+
struct ErrorExecProto {
139+
#[prost(string, tag = "1")]
140+
msg: String,
141+
}
142+
143+
impl PhysicalExtensionCodec for ErrorExecCodec {
144+
fn try_decode(
145+
&self,
146+
buf: &[u8],
147+
_: &[Arc<dyn ExecutionPlan>],
148+
_registry: &dyn FunctionRegistry,
149+
) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
150+
let node = ErrorExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?;
151+
Ok(Arc::new(ErrorExec::new(&node.msg)))
152+
}
153+
154+
fn try_encode(
155+
&self,
156+
node: Arc<dyn ExecutionPlan>,
157+
buf: &mut Vec<u8>,
158+
) -> datafusion::common::Result<()> {
159+
let Some(plan) = node.as_any().downcast_ref::<ErrorExec>() else {
160+
return Err(proto_error(format!(
161+
"Expected plan to be of type ErrorExec, but was {}",
162+
node.name()
163+
)));
164+
};
165+
ErrorExecProto {
166+
msg: plan.msg.clone(),
167+
}
168+
.encode(buf)
169+
.map_err(|err| proto_error(format!("{err}")))
170+
}
171+
}
172+
}

0 commit comments

Comments
 (0)