diff --git a/Cargo.lock b/Cargo.lock
index 979090d..9decb50 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1010,6 +1010,7 @@ dependencies = [
  "futures",
  "http",
  "insta",
+ "object_store",
  "prost",
  "tokio",
  "tonic",
diff --git a/Cargo.toml b/Cargo.toml
index ad4bb14..6f62448 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,7 @@ uuid = "1.17.0"
 delegate = "0.13.4"
 dashmap = "6.1.0"
 prost = "0.13.5"
+object_store = "0.12.3"
 
 [dev-dependencies]
-insta = { version = "1.43.1" , features = ["filters"]}
\ No newline at end of file
+insta = { version = "1.43.1", features = ["filters"] }
\ No newline at end of file
diff --git a/src/errors/arrow_error.rs b/src/errors/arrow_error.rs
new file mode 100644
index 0000000..c810609
--- /dev/null
+++ b/src/errors/arrow_error.rs
@@ -0,0 +1,262 @@
+use crate::errors::io_error::IoErrorProto;
+use datafusion::arrow::error::ArrowError;
+
+/// Protobuf representation of an [ArrowError] together with the optional context string
+/// that accompanies it.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ArrowErrorProto {
+    /// Optional context message carried alongside the error.
+    #[prost(string, optional, tag = "1")]
+    pub ctx: Option<String>,
+    /// The encoded error variant; when absent the message is treated as malformed on decode.
+    #[prost(
+        oneof = "ArrowErrorInnerProto",
+        tags = "2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19"
+    )]
+    pub inner: Option<ArrowErrorInnerProto>,
+}
+
+/// One `oneof` case per [ArrowError] variant.
+#[derive(Clone, PartialEq, prost::Oneof)]
+pub enum ArrowErrorInnerProto {
+    #[prost(string, tag = "2")]
+    NotYetImplemented(String),
+    #[prost(string, tag = "3")]
+    ExternalError(String),
+    #[prost(string, tag = "4")]
+    CastError(String),
+    #[prost(string, tag = "5")]
+    MemoryError(String),
+    #[prost(string, tag = "6")]
+    ParseError(String),
+    #[prost(string, tag = "7")]
+    SchemaError(String),
+    #[prost(string, tag = "8")]
+    ComputeError(String),
+    /// Unit variant: the `bool` payload is only a placeholder and is ignored when decoding.
+    #[prost(bool, tag = "9")]
+    DivideByZero(bool),
+    #[prost(string, tag = "10")]
+    ArithmeticOverflow(String),
+    #[prost(string, tag = "11")]
+    CsvError(String),
+    #[prost(string, tag = "12")]
+    JsonError(String),
+    #[prost(message, tag = "13")]
+    IoError(IoErrorProto),
+    #[prost(message, tag = "14")]
+    IpcError(String),
+    #[prost(message, tag = "15")]
+    InvalidArgumentError(String),
+
#[prost(message, tag = "16")] + ParquetError(String), + #[prost(message, tag = "17")] + CDataInterface(String), + #[prost(bool, tag = "18")] + DictionaryKeyOverflowError(bool), + #[prost(bool, tag = "19")] + RunEndIndexOverflowError(bool), +} + +impl ArrowErrorProto { + pub fn from_arrow_error(err: &ArrowError, ctx: Option<&String>) -> Self { + match err { + ArrowError::NotYetImplemented(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::NotYetImplemented(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::ExternalError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::ExternalError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::CastError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::CastError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::MemoryError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::MemoryError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::ParseError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::ParseError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::SchemaError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::SchemaError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::ComputeError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::ComputeError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::DivideByZero => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::DivideByZero(true)), + ctx: ctx.cloned(), + }, + ArrowError::ArithmeticOverflow(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::ArithmeticOverflow(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::CsvError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::CsvError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::JsonError(msg) => ArrowErrorProto { + inner: Some(ArrowErrorInnerProto::JsonError(msg.to_string())), + ctx: ctx.cloned(), + }, + ArrowError::IoError(msg, err) => 
ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::IoError(IoErrorProto::from_io_error(
+                    msg, err,
+                ))),
+                ctx: ctx.cloned(),
+            },
+            ArrowError::IpcError(msg) => ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::IpcError(msg.to_string())),
+                ctx: ctx.cloned(),
+            },
+            ArrowError::InvalidArgumentError(msg) => ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::InvalidArgumentError(msg.to_string())),
+                ctx: ctx.cloned(),
+            },
+            ArrowError::ParquetError(msg) => ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::ParquetError(msg.to_string())),
+                ctx: ctx.cloned(),
+            },
+            ArrowError::CDataInterface(msg) => ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::CDataInterface(msg.to_string())),
+                ctx: ctx.cloned(),
+            },
+            ArrowError::DictionaryKeyOverflowError => ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::DictionaryKeyOverflowError(true)),
+                ctx: ctx.cloned(),
+            },
+            ArrowError::RunEndIndexOverflowError => ArrowErrorProto {
+                inner: Some(ArrowErrorInnerProto::RunEndIndexOverflowError(true)),
+                ctx: ctx.cloned(),
+            },
+        }
+    }
+
+    pub fn to_arrow_error(&self) -> (ArrowError, Option<String>) {
+        let Some(ref inner) = self.inner else { // no oneof case set: report a malformed message
+            return (
+                ArrowError::ExternalError(Box::from("Malformed protobuf message".to_string())),
+                None,
+            );
+        };
+        let err = match inner {
+            ArrowErrorInnerProto::NotYetImplemented(msg) => {
+                ArrowError::NotYetImplemented(msg.to_string())
+            }
+            ArrowErrorInnerProto::ExternalError(msg) => {
+                ArrowError::ExternalError(Box::from(msg.to_string()))
+            }
+            ArrowErrorInnerProto::CastError(msg) => ArrowError::CastError(msg.to_string()),
+            ArrowErrorInnerProto::MemoryError(msg) => ArrowError::MemoryError(msg.to_string()),
+            ArrowErrorInnerProto::ParseError(msg) => ArrowError::ParseError(msg.to_string()),
+            ArrowErrorInnerProto::SchemaError(msg) => ArrowError::SchemaError(msg.to_string()),
+            ArrowErrorInnerProto::ComputeError(msg) => ArrowError::ComputeError(msg.to_string()),
+            ArrowErrorInnerProto::DivideByZero(_) =>
ArrowError::DivideByZero, + ArrowErrorInnerProto::ArithmeticOverflow(msg) => { + ArrowError::ArithmeticOverflow(msg.to_string()) + } + ArrowErrorInnerProto::CsvError(msg) => ArrowError::CsvError(msg.to_string()), + ArrowErrorInnerProto::JsonError(msg) => ArrowError::JsonError(msg.to_string()), + ArrowErrorInnerProto::IoError(msg) => { + let (msg, err) = msg.to_io_error(); + ArrowError::IoError(err, msg) + } + ArrowErrorInnerProto::IpcError(msg) => ArrowError::IpcError(msg.to_string()), + ArrowErrorInnerProto::InvalidArgumentError(msg) => { + ArrowError::InvalidArgumentError(msg.to_string()) + } + ArrowErrorInnerProto::ParquetError(msg) => ArrowError::ParquetError(msg.to_string()), + ArrowErrorInnerProto::CDataInterface(msg) => { + ArrowError::CDataInterface(msg.to_string()) + } + ArrowErrorInnerProto::DictionaryKeyOverflowError(_) => { + ArrowError::DictionaryKeyOverflowError + } + ArrowErrorInnerProto::RunEndIndexOverflowError(_) => { + ArrowError::RunEndIndexOverflowError + } + }; + (err, self.ctx.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use prost::Message; + use std::io::{Error as IoError, ErrorKind}; + + #[test] + fn test_arrow_error_roundtrip() { + let test_cases = vec![ + ArrowError::NotYetImplemented("test not implemented".to_string()), + ArrowError::ExternalError(Box::new(std::io::Error::new( + ErrorKind::Other, + "external error", + ))), + ArrowError::CastError("cast error".to_string()), + ArrowError::MemoryError("memory error".to_string()), + ArrowError::ParseError("parse error".to_string()), + ArrowError::SchemaError("schema error".to_string()), + ArrowError::ComputeError("compute error".to_string()), + ArrowError::DivideByZero, + ArrowError::ArithmeticOverflow("overflow".to_string()), + ArrowError::CsvError("csv error".to_string()), + ArrowError::JsonError("json error".to_string()), + ArrowError::IoError( + "io message".to_string(), + IoError::new(ErrorKind::NotFound, "file not found"), + ), + ArrowError::IpcError("ipc 
error".to_string()), + ArrowError::InvalidArgumentError("invalid arg".to_string()), + ArrowError::ParquetError("parquet error".to_string()), + ArrowError::CDataInterface("cdata error".to_string()), + ArrowError::DictionaryKeyOverflowError, + ArrowError::RunEndIndexOverflowError, + ]; + + for original_error in test_cases { + let proto = ArrowErrorProto::from_arrow_error( + &original_error, + Some(&"test context".to_string()), + ); + let proto = ArrowErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let (recovered_error, recovered_ctx) = proto.to_arrow_error(); + + if original_error.to_string() != recovered_error.to_string() { + println!("original error: {}", original_error.to_string()); + println!("recovered error: {}", recovered_error.to_string()); + } + + assert_eq!(original_error.to_string(), recovered_error.to_string()); + assert_eq!(recovered_ctx, Some("test context".to_string())); + + let proto_no_ctx = ArrowErrorProto::from_arrow_error(&original_error, None); + let proto_no_ctx = + ArrowErrorProto::decode(proto_no_ctx.encode_to_vec().as_ref()).unwrap(); + let (recovered_error_no_ctx, recovered_ctx_no_ctx) = proto_no_ctx.to_arrow_error(); + + assert_eq!( + original_error.to_string(), + recovered_error_no_ctx.to_string() + ); + assert_eq!(recovered_ctx_no_ctx, None); + } + } + + #[test] + fn test_malformed_protobuf_message() { + let malformed_proto = ArrowErrorProto { + inner: None, + ctx: None, + }; + let (recovered_error, _) = malformed_proto.to_arrow_error(); + assert!(matches!(recovered_error, ArrowError::ExternalError(_))); + } +} diff --git a/src/errors/datafusion_error.rs b/src/errors/datafusion_error.rs new file mode 100644 index 0000000..8bd7934 --- /dev/null +++ b/src/errors/datafusion_error.rs @@ -0,0 +1,404 @@ +use crate::errors::arrow_error::ArrowErrorProto; +use crate::errors::io_error::IoErrorProto; +use crate::errors::objectstore_error::ObjectStoreErrorProto; +use crate::errors::parquet_error::ParquetErrorProto; +use 
crate::errors::parser_error::ParserErrorProto;
+use crate::errors::schema_error::SchemaErrorProto;
+use datafusion::common::{DataFusionError, Diagnostic};
+use datafusion::logical_expr::sqlparser::parser::ParserError;
+use std::error::Error;
+use std::sync::Arc;
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DataFusionErrorProto {
+    #[prost(
+        oneof = "DataFusionErrorInnerProto",
+        tags = "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19"
+    )]
+    pub inner: Option<DataFusionErrorInnerProto>, // `None` decodes to an Internal error
+}
+
+#[derive(Clone, PartialEq, prost::Oneof)]
+pub enum DataFusionErrorInnerProto {
+    #[prost(message, tag = "1")]
+    ArrowError(ArrowErrorProto),
+    #[prost(message, tag = "2")]
+    ParquetError(ParquetErrorProto),
+    #[prost(message, tag = "3")]
+    ObjectStoreError(ObjectStoreErrorProto),
+    #[prost(message, tag = "4")]
+    IoError(IoErrorProto),
+    #[prost(message, tag = "5")]
+    SQL(DataFusionSqlErrorProto),
+    #[prost(string, tag = "6")]
+    NotImplemented(String),
+    #[prost(string, tag = "7")]
+    Internal(String),
+    #[prost(string, tag = "8")]
+    Plan(String),
+    #[prost(string, tag = "9")]
+    Configuration(String),
+    #[prost(message, tag = "10")]
+    Schema(SchemaErrorProto),
+    #[prost(string, tag = "11")]
+    Execution(String),
+    #[prost(string, tag = "12")]
+    ExecutionJoin(String),
+    #[prost(string, tag = "13")]
+    ResourceExhausted(String),
+    #[prost(string, tag = "14")]
+    External(String),
+    #[prost(message, tag = "15")]
+    Context(DataFusionContextErrorProto),
+    #[prost(string, tag = "16")]
+    Substrait(String),
+    #[prost(message, boxed, tag = "17")]
+    Diagnostic(Box<DataFusionErrorProto>), // only the wrapped error is encoded
+    #[prost(message, tag = "18")]
+    Collection(DataFusionCollectionErrorProto),
+    #[prost(message, boxed, tag = "19")]
+    Shared(Box<DataFusionErrorProto>),
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DataFusionSqlErrorProto {
+    #[prost(message, tag = "1")]
+    err: Option<ParserErrorProto>,
+    #[prost(string, optional, tag = "2")]
+    backtrace: Option<String>,
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DataFusionContextErrorProto {
+    #[prost(message, boxed, tag = "1")]
+    err: Option<Box<DataFusionErrorProto>>,
+    #[prost(string, tag = "2")]
+    ctx: String,
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct DataFusionCollectionErrorProto {
+    #[prost(message, repeated, boxed, tag = "1")]
+    errs: Vec<Box<DataFusionErrorProto>>,
+}
+
+impl DataFusionErrorProto {
+    pub fn from_datafusion_error(err: &DataFusionError) -> Self {
+        match err {
+            DataFusionError::ArrowError(err, msg) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::ArrowError(
+                    ArrowErrorProto::from_arrow_error(err, msg.as_ref()),
+                )),
+            },
+            DataFusionError::ParquetError(err) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::ParquetError(
+                    ParquetErrorProto::from_parquet_error(err),
+                )),
+            },
+            DataFusionError::ObjectStore(err) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::ObjectStoreError(
+                    ObjectStoreErrorProto::from_object_store_error(err),
+                )),
+            },
+            DataFusionError::IoError(err) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::IoError(
+                    IoErrorProto::from_io_error("", err),
+                )),
+            },
+            DataFusionError::SQL(err, msg) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::SQL(DataFusionSqlErrorProto {
+                    err: Some(ParserErrorProto::from_parser_error(err)),
+                    backtrace: msg.clone(),
+                })),
+            },
+            DataFusionError::NotImplemented(msg) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::NotImplemented(msg.clone())),
+            },
+            DataFusionError::Internal(msg) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::Internal(msg.clone())),
+            },
+            DataFusionError::Plan(msg) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::Plan(msg.clone())),
+            },
+            DataFusionError::Configuration(msg) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::Configuration(msg.clone())),
+            },
+            DataFusionError::SchemaError(err, backtrace) => DataFusionErrorProto {
+                inner: Some(DataFusionErrorInnerProto::Schema(
+                    SchemaErrorProto::from_schema_error(err,
backtrace.as_ref().as_ref()), + )), + }, + DataFusionError::Execution(msg) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::Execution(msg.clone())), + }, + DataFusionError::ExecutionJoin(err) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::ExecutionJoin(err.to_string())), + }, + DataFusionError::ResourcesExhausted(msg) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::ResourceExhausted(msg.clone())), + }, + DataFusionError::External(err) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::External(err.to_string())), + }, + DataFusionError::Context(ctx, err) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::Context( + DataFusionContextErrorProto { + ctx: ctx.to_string(), + err: Some(Box::new(DataFusionErrorProto::from_datafusion_error(err))), + }, + )), + }, + DataFusionError::Substrait(err) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::Substrait(err.to_string())), + }, + // Diagnostics are trimmed out + DataFusionError::Diagnostic(_, err) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::Diagnostic(Box::new( + DataFusionErrorProto::from_datafusion_error(err), + ))), + }, + DataFusionError::Collection(errs) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::Collection( + DataFusionCollectionErrorProto { + errs: errs + .iter() + .map(DataFusionErrorProto::from_datafusion_error) + .map(Box::new) + .collect(), + }, + )), + }, + DataFusionError::Shared(err) => DataFusionErrorProto { + inner: Some(DataFusionErrorInnerProto::Shared(Box::new( + DataFusionErrorProto::from_datafusion_error(err.as_ref()), + ))), + }, + } + } + + pub fn to_datafusion_err(&self) -> DataFusionError { + let Some(ref inner) = self.inner else { + return DataFusionError::Internal("DataFusionError proto message is empty".to_string()); + }; + + match inner { + DataFusionErrorInnerProto::ArrowError(err) => { + let (err, ctx) = err.to_arrow_error(); + 
DataFusionError::ArrowError(err, ctx) + } + DataFusionErrorInnerProto::ParquetError(err) => { + DataFusionError::ParquetError(err.to_parquet_error()) + } + DataFusionErrorInnerProto::ObjectStoreError(err) => { + DataFusionError::ObjectStore(err.to_object_store_error()) + } + DataFusionErrorInnerProto::IoError(err) => { + let (err, _) = err.to_io_error(); + DataFusionError::IoError(err) + } + DataFusionErrorInnerProto::SQL(err) => { + let backtrace = err.backtrace.clone(); + let err = err.err.as_ref().map(|err| err.to_parser_error()); + let err = err.unwrap_or(ParserError::ParserError("".to_string())); + DataFusionError::SQL(err, backtrace) + } + DataFusionErrorInnerProto::NotImplemented(msg) => { + DataFusionError::NotImplemented(msg.clone()) + } + DataFusionErrorInnerProto::Internal(msg) => DataFusionError::Internal(msg.clone()), + DataFusionErrorInnerProto::Plan(msg) => DataFusionError::Plan(msg.clone()), + DataFusionErrorInnerProto::Configuration(msg) => { + DataFusionError::Configuration(msg.clone()) + } + DataFusionErrorInnerProto::Schema(err) => { + let (err, backtrace) = err.to_schema_error(); + DataFusionError::SchemaError(err, Box::new(backtrace)) + } + DataFusionErrorInnerProto::Execution(msg) => DataFusionError::Execution(msg.clone()), + // We cannot build JoinErrors ourselves, so instead we map it to internal. 
+            DataFusionErrorInnerProto::ExecutionJoin(msg) => DataFusionError::Internal(msg.clone()),
+            DataFusionErrorInnerProto::ResourceExhausted(msg) => {
+                DataFusionError::ResourcesExhausted(msg.clone())
+            }
+            DataFusionErrorInnerProto::External(generic) => {
+                DataFusionError::External(Box::new(DistributedDataFusionGenericError {
+                    message: generic.clone(),
+                }))
+            }
+            DataFusionErrorInnerProto::Context(err) => DataFusionError::Context(
+                err.ctx.clone(),
+                Box::new(err.err.as_ref().map(|v| v.to_datafusion_err()).unwrap_or_else(
+                    || DataFusionError::Internal( // built lazily, only when the nested error is missing
+                        "Missing DataFusionError protobuf message".to_string(),
+                    ),
+                )),
+            ),
+            DataFusionErrorInnerProto::Substrait(msg) => DataFusionError::Substrait(msg.clone()),
+            DataFusionErrorInnerProto::Diagnostic(err) => {
+                DataFusionError::Diagnostic(
+                    // We lose diagnostic information because we are not encoding it.
+                    Box::new(Diagnostic::new_error("", None)),
+                    Box::new(err.to_datafusion_err()),
+                )
+            }
+            DataFusionErrorInnerProto::Collection(errs) => DataFusionError::Collection(
+                errs.errs
+                    .iter()
+                    .map(|err| err.to_datafusion_err())
+                    .collect(),
+            ),
+            DataFusionErrorInnerProto::Shared(err) => {
+                DataFusionError::Shared(Arc::new(err.to_datafusion_err()))
+            }
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct DistributedDataFusionGenericError {
+    pub message: String,
+}
+
+impl std::fmt::Display for DistributedDataFusionGenericError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.message)
+    }
+}
+
+impl Error for DistributedDataFusionGenericError {}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::arrow::error::ArrowError;
+    use datafusion::common::{DataFusionError, SchemaError};
+    use datafusion::logical_expr::sqlparser::parser::ParserError;
+    use datafusion::parquet::errors::ParquetError;
+    use object_store::Error as ObjectStoreError;
+    use prost::Message;
+    use std::io::{Error as IoError, ErrorKind};
+    use std::sync::Arc;
+
+    #[test]
+    fn
test_datafusion_error_roundtrip() { + let test_cases = vec![ + DataFusionError::ArrowError( + ArrowError::ComputeError("compute".to_string()), + Some("arrow context".to_string()), + ), + DataFusionError::ParquetError(ParquetError::General("parquet error".to_string())), + DataFusionError::ObjectStore(ObjectStoreError::NotFound { + path: "test/path".to_string(), + source: Box::new(std::io::Error::new(ErrorKind::NotFound, "not found")), + }), + DataFusionError::IoError(IoError::new( + ErrorKind::PermissionDenied, + "permission denied", + )), + DataFusionError::SQL( + ParserError::ParserError("sql parse error".to_string()), + Some("sql backtrace".to_string()), + ), + DataFusionError::NotImplemented("not implemented".to_string()), + DataFusionError::Internal("internal error".to_string()), + DataFusionError::Plan("plan error".to_string()), + DataFusionError::Configuration("config error".to_string()), + DataFusionError::SchemaError( + SchemaError::AmbiguousReference { + field: datafusion::common::Column::new_unqualified("test_field"), + }, + Box::new(None), + ), + DataFusionError::Execution("execution error".to_string()), + DataFusionError::ResourcesExhausted("resources exhausted".to_string()), + DataFusionError::External(Box::new(std::io::Error::new(ErrorKind::Other, "external"))), + DataFusionError::Context( + "context message".to_string(), + Box::new(DataFusionError::Internal("nested".to_string())), + ), + DataFusionError::Substrait("substrait error".to_string()), + DataFusionError::Collection(vec![ + DataFusionError::Internal("error 1".to_string()), + DataFusionError::Internal("error 2".to_string()), + ]), + DataFusionError::Shared(Arc::new(DataFusionError::Internal( + "shared error".to_string(), + ))), + ]; + + for original_error in test_cases { + let proto = DataFusionErrorProto::from_datafusion_error(&original_error); + let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_datafusion_err(); + + 
assert_eq!(original_error.to_string(), recovered_error.to_string()); + } + } + + #[test] + fn test_malformed_protobuf_message() { + let malformed_proto = DataFusionErrorProto { inner: None }; + let recovered_error = malformed_proto.to_datafusion_err(); + assert!(matches!(recovered_error, DataFusionError::Internal(_))); + } + + #[test] + fn test_nested_datafusion_errors() { + let nested_error = DataFusionError::Context( + "outer context".to_string(), + Box::new(DataFusionError::Context( + "inner context".to_string(), + Box::new(DataFusionError::Internal("deepest error".to_string())), + )), + ); + + let proto = DataFusionErrorProto::from_datafusion_error(&nested_error); + let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_datafusion_err(); + + assert_eq!(nested_error.to_string(), recovered_error.to_string()); + } + + #[test] + fn test_collection_errors() { + let collection_error = DataFusionError::Collection(vec![ + DataFusionError::Internal("error 1".to_string()), + DataFusionError::Plan("error 2".to_string()), + DataFusionError::Execution("error 3".to_string()), + ]); + + let proto = DataFusionErrorProto::from_datafusion_error(&collection_error); + let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_datafusion_err(); + + assert_eq!(collection_error.to_string(), recovered_error.to_string()); + } + + #[test] + fn test_sql_error_with_backtrace() { + let sql_error = DataFusionError::SQL( + ParserError::ParserError("syntax error".to_string()), + Some("test backtrace".to_string()), + ); + + let proto = DataFusionErrorProto::from_datafusion_error(&sql_error); + let proto = DataFusionErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_datafusion_err(); + + if let DataFusionError::SQL(_, backtrace) = recovered_error { + assert_eq!(backtrace, Some("test backtrace".to_string())); + } else { + 
panic!("Expected SQL error"); + } + } + + #[test] + fn test_distributed_generic_error() { + let generic_error = DistributedDataFusionGenericError { + message: "test message".to_string(), + }; + + assert_eq!(generic_error.to_string(), "test message"); + assert!(Error::source(&generic_error).is_none()); + } +} diff --git a/src/errors/io_error.rs b/src/errors/io_error.rs new file mode 100644 index 0000000..ff5d271 --- /dev/null +++ b/src/errors/io_error.rs @@ -0,0 +1,176 @@ +use std::io::ErrorKind; + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct IoErrorProto { + #[prost(string, tag = "1")] + pub msg: String, + #[prost(int32, tag = "2")] + pub code: i32, + #[prost(string, tag = "3")] + pub err: String, +} + +impl IoErrorProto { + pub(crate) fn from_io_error(msg: &str, err: &std::io::Error) -> Self { + Self { + msg: msg.to_string(), + code: match err.kind() { + ErrorKind::NotFound => 0, + ErrorKind::PermissionDenied => 1, + ErrorKind::ConnectionRefused => 2, + ErrorKind::ConnectionReset => 3, + ErrorKind::HostUnreachable => 4, + ErrorKind::NetworkUnreachable => 5, + ErrorKind::ConnectionAborted => 6, + ErrorKind::NotConnected => 7, + ErrorKind::AddrInUse => 8, + ErrorKind::AddrNotAvailable => 9, + ErrorKind::NetworkDown => 10, + ErrorKind::BrokenPipe => 11, + ErrorKind::AlreadyExists => 12, + ErrorKind::WouldBlock => 13, + ErrorKind::NotADirectory => 14, + ErrorKind::IsADirectory => 15, + ErrorKind::DirectoryNotEmpty => 16, + ErrorKind::ReadOnlyFilesystem => 17, + ErrorKind::StaleNetworkFileHandle => 18, + ErrorKind::InvalidInput => 19, + ErrorKind::InvalidData => 20, + ErrorKind::TimedOut => 21, + ErrorKind::WriteZero => 22, + ErrorKind::StorageFull => 23, + ErrorKind::NotSeekable => 24, + ErrorKind::QuotaExceeded => 25, + ErrorKind::FileTooLarge => 26, + ErrorKind::ResourceBusy => 27, + ErrorKind::ExecutableFileBusy => 28, + ErrorKind::Deadlock => 29, + ErrorKind::CrossesDevices => 30, + ErrorKind::TooManyLinks => 31, + ErrorKind::ArgumentListTooLong => 
32, + ErrorKind::Interrupted => 33, + ErrorKind::Unsupported => 34, + ErrorKind::UnexpectedEof => 35, + ErrorKind::OutOfMemory => 36, + ErrorKind::Other => 37, + _ => -1, + }, + err: err.to_string(), + } + } + + pub(crate) fn to_io_error(&self) -> (std::io::Error, String) { + let kind = match self.code { + 0 => ErrorKind::NotFound, + 1 => ErrorKind::PermissionDenied, + 2 => ErrorKind::ConnectionRefused, + 3 => ErrorKind::ConnectionReset, + 4 => ErrorKind::HostUnreachable, + 5 => ErrorKind::NetworkUnreachable, + 6 => ErrorKind::ConnectionAborted, + 7 => ErrorKind::NotConnected, + 8 => ErrorKind::AddrInUse, + 9 => ErrorKind::AddrNotAvailable, + 10 => ErrorKind::NetworkDown, + 11 => ErrorKind::BrokenPipe, + 12 => ErrorKind::AlreadyExists, + 13 => ErrorKind::WouldBlock, + 14 => ErrorKind::NotADirectory, + 15 => ErrorKind::IsADirectory, + 16 => ErrorKind::DirectoryNotEmpty, + 17 => ErrorKind::ReadOnlyFilesystem, + 18 => ErrorKind::StaleNetworkFileHandle, + 19 => ErrorKind::InvalidInput, + 20 => ErrorKind::InvalidData, + 21 => ErrorKind::TimedOut, + 22 => ErrorKind::WriteZero, + 23 => ErrorKind::StorageFull, + 24 => ErrorKind::NotSeekable, + 25 => ErrorKind::QuotaExceeded, + 26 => ErrorKind::FileTooLarge, + 27 => ErrorKind::ResourceBusy, + 28 => ErrorKind::ExecutableFileBusy, + 29 => ErrorKind::Deadlock, + 30 => ErrorKind::CrossesDevices, + 31 => ErrorKind::TooManyLinks, + 32 => ErrorKind::ArgumentListTooLong, + 33 => ErrorKind::Interrupted, + 34 => ErrorKind::Unsupported, + 35 => ErrorKind::UnexpectedEof, + 36 => ErrorKind::OutOfMemory, + 37 => ErrorKind::Other, + _ => ErrorKind::Other, + }; + ( + std::io::Error::new(kind, self.err.clone()), + self.msg.clone(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use prost::Message; + use std::io::{Error as IoError, ErrorKind}; + + #[test] + fn test_io_error_roundtrip() { + let test_cases = vec![ + (ErrorKind::NotFound, "file not found"), + (ErrorKind::PermissionDenied, "permission denied"), + 
(ErrorKind::ConnectionRefused, "connection refused"), + (ErrorKind::ConnectionReset, "connection reset"), + (ErrorKind::ConnectionAborted, "connection aborted"), + (ErrorKind::NotConnected, "not connected"), + (ErrorKind::AddrInUse, "address in use"), + (ErrorKind::AddrNotAvailable, "address not available"), + (ErrorKind::BrokenPipe, "broken pipe"), + (ErrorKind::AlreadyExists, "already exists"), + (ErrorKind::WouldBlock, "would block"), + (ErrorKind::InvalidInput, "invalid input"), + (ErrorKind::InvalidData, "invalid data"), + (ErrorKind::TimedOut, "timed out"), + (ErrorKind::WriteZero, "write zero"), + (ErrorKind::Interrupted, "interrupted"), + (ErrorKind::UnexpectedEof, "unexpected eof"), + (ErrorKind::Other, "other error"), + ]; + + for (kind, msg) in test_cases { + let original_error = IoError::new(kind, msg); + let proto = IoErrorProto::from_io_error("test message", &original_error); + let proto = IoErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let (recovered_error, recovered_message) = proto.to_io_error(); + + assert_eq!(original_error.kind(), recovered_error.kind()); + assert_eq!(original_error.to_string(), recovered_error.to_string()); + assert_eq!(recovered_message, "test message"); + } + } + + #[test] + fn test_protobuf_serialization() { + let original_error = IoError::new(ErrorKind::NotFound, "file not found"); + let proto = IoErrorProto::from_io_error("test message", &original_error); + let proto = IoErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let (recovered_error, recovered_message) = proto.to_io_error(); + + assert_eq!(original_error.kind(), recovered_error.kind()); + assert_eq!(original_error.to_string(), recovered_error.to_string()); + assert_eq!(recovered_message, "test message"); + } + + #[test] + fn test_unknown_error_kind() { + let proto = IoErrorProto { + msg: "test message".to_string(), + code: -1, + err: "unknown error".to_string(), + }; + let (recovered_error, recovered_message) = proto.to_io_error(); + + 
assert_eq!(recovered_error.kind(), ErrorKind::Other);
+        assert_eq!(recovered_message, "test message");
+    }
+}
diff --git a/src/errors/mod.rs b/src/errors/mod.rs
new file mode 100644
index 0000000..172aa86
--- /dev/null
+++ b/src/errors/mod.rs
@@ -0,0 +1,46 @@
+use crate::errors::datafusion_error::DataFusionErrorProto;
+use datafusion::common::internal_datafusion_err;
+use datafusion::error::DataFusionError;
+use prost::Message;
+
+mod arrow_error;
+mod datafusion_error;
+mod io_error;
+mod objectstore_error;
+mod parquet_error;
+mod parser_error;
+mod schema_error;
+
+/// Encodes a [DataFusionError] into a [tonic::Status] error. The produced error is suitable
+/// to be sent over the wire and decoded by the receiving end, recovering the original
+/// [DataFusionError] across a network boundary with [tonic_status_to_datafusion_error].
+pub fn datafusion_error_to_tonic_status(err: &DataFusionError) -> tonic::Status {
+    let err = DataFusionErrorProto::from_datafusion_error(err);
+    let err = err.encode_to_vec();
+    tonic::Status::with_details(tonic::Code::Internal, "DataFusionError", err.into())
+}
+
+/// Decodes a [DataFusionError] from a [tonic::Status] error. If the provided [tonic::Status]
+/// error was produced with [datafusion_error_to_tonic_status], this function will be able to
+/// recover it even across a network boundary.
+///
+/// The provided [tonic::Status] error might also be something else, like an actual network
+/// failure. This function returns `None` for those cases.
+///
+/// If the status looks like an encoded [DataFusionError] but its details cannot be decoded,
+/// an internal [DataFusionError] describing the decoding failure is returned instead.
+pub fn tonic_status_to_datafusion_error(status: &tonic::Status) -> Option<DataFusionError> {
+    if status.code() != tonic::Code::Internal {
+        return None;
+    }
+
+    if status.message() != "DataFusionError" {
+        return None;
+    }
+
+    match DataFusionErrorProto::decode(status.details()) {
+        Ok(err_proto) => Some(err_proto.to_datafusion_err()),
+        Err(err) => Some(internal_datafusion_err!(
+            "Cannot decode DataFusionError: {err}"
+        )),
+    }
+}
diff --git a/src/errors/objectstore_error.rs b/src/errors/objectstore_error.rs
new file mode 100644
index 0000000..45d0ff7
--- /dev/null
+++ b/src/errors/objectstore_error.rs
@@ -0,0 +1,324 @@
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ObjectStoreErrorProto {
+    #[prost(
+        oneof = "ObjectStoreErrorInnerProto",
+        tags = "1,2,3,4,5,6,7,8,9,10,11,12"
+    )]
+    pub inner: Option<ObjectStoreErrorInnerProto>,
+}
+
+#[derive(Clone, PartialEq, prost::Oneof)]
+pub enum ObjectStoreErrorInnerProto {
+    #[prost(message, tag = "1")]
+    Generic(ObjectStoreGenericErrorProto),
+    #[prost(message, tag = "2")]
+    NotFound(ObjectStoreSourcePathErrorProto),
+    #[prost(message, tag = "3")]
+    InvalidPath(ObjectStoreSourceErrorProto),
+    #[prost(message, tag = "4")]
+    JoinError(ObjectStoreSourceErrorProto),
+    #[prost(message, tag = "5")]
+    NotSupported(ObjectStoreSourceErrorProto),
+    #[prost(message, tag = "6")]
+    AlreadyExists(ObjectStoreSourcePathErrorProto),
+    #[prost(message, tag = "7")]
+    Precondition(ObjectStoreSourcePathErrorProto),
+    #[prost(message, tag = "8")]
+    NotModified(ObjectStoreSourcePathErrorProto),
+    #[prost(message, tag = "9")]
+    NotImplemented(bool),
+    #[prost(message, tag = "10")]
+    PermissionDenied(ObjectStoreSourcePathErrorProto),
+    #[prost(message, tag = "11")]
+    Unauthenticated(ObjectStoreSourcePathErrorProto),
+    #[prost(message, tag = "12")]
+    UnknownConfigurationKey(ObjectStoreConfigurationKeyErrorProto),
+}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ObjectStoreGenericErrorProto {
+    #[prost(string, tag = "1")]
+    store:
String, + #[prost(string, tag = "2")] + source: String, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ObjectStoreSourceErrorProto { + #[prost(string, tag = "1")] + source: String, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ObjectStoreSourcePathErrorProto { + #[prost(string, tag = "1")] + path: String, + #[prost(string, tag = "2")] + source: String, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ObjectStoreConfigurationKeyErrorProto { + #[prost(string, tag = "1")] + key: String, + #[prost(string, tag = "2")] + store: String, +} + +impl ObjectStoreErrorProto { + pub fn from_object_store_error(err: &object_store::Error) -> Self { + match err { + object_store::Error::Generic { store, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::Generic( + ObjectStoreGenericErrorProto { + store: store.to_string(), + source: source.to_string(), + }, + )), + }, + object_store::Error::NotFound { path, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::NotFound( + ObjectStoreSourcePathErrorProto { + path: path.to_string(), + source: source.to_string(), + }, + )), + }, + object_store::Error::InvalidPath { source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::InvalidPath( + ObjectStoreSourceErrorProto { + source: source.to_string(), + }, + )), + }, + object_store::Error::JoinError { source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::JoinError( + ObjectStoreSourceErrorProto { + source: source.to_string(), + }, + )), + }, + object_store::Error::NotSupported { source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::NotSupported( + ObjectStoreSourceErrorProto { + source: source.to_string(), + }, + )), + }, + object_store::Error::AlreadyExists { path, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::AlreadyExists( + ObjectStoreSourcePathErrorProto { + path: path.to_string(), + 
source: source.to_string(), + }, + )), + }, + object_store::Error::Precondition { path, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::Precondition( + ObjectStoreSourcePathErrorProto { + path: path.to_string(), + source: source.to_string(), + }, + )), + }, + object_store::Error::NotModified { path, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::NotModified( + ObjectStoreSourcePathErrorProto { + path: path.to_string(), + source: source.to_string(), + }, + )), + }, + object_store::Error::NotImplemented => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::NotImplemented(true)), + }, + object_store::Error::PermissionDenied { path, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::PermissionDenied( + ObjectStoreSourcePathErrorProto { + path: path.to_string(), + source: source.to_string(), + }, + )), + }, + object_store::Error::Unauthenticated { path, source } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::Unauthenticated( + ObjectStoreSourcePathErrorProto { + path: path.to_string(), + source: source.to_string(), + }, + )), + }, + object_store::Error::UnknownConfigurationKey { key, store } => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::UnknownConfigurationKey( + ObjectStoreConfigurationKeyErrorProto { + key: key.to_string(), + store: store.to_string(), + }, + )), + }, + _ => ObjectStoreErrorProto { + inner: Some(ObjectStoreErrorInnerProto::Generic( + ObjectStoreGenericErrorProto { + store: "Could not serialize ObjectStore error to proto".to_string(), + source: "Could not serialize ObjectStore error to proto".to_string(), + }, + )), + }, + } + } + + pub fn to_object_store_error(&self) -> object_store::Error { + let Some(ref inner) = self.inner else { + return object_store::Error::Generic { + store: "unknown", + source: "Could not deserialize ObjectStore error from proto".into(), + }; + }; + + match inner { + 
ObjectStoreErrorInnerProto::Generic(msg) => object_store::Error::Generic { + store: parse_store(&msg.store), + source: msg.source.clone().into(), + }, + ObjectStoreErrorInnerProto::NotFound(msg) => object_store::Error::NotFound { + path: msg.path.clone(), + source: msg.source.clone().into(), + }, + ObjectStoreErrorInnerProto::InvalidPath(msg) => object_store::Error::Generic { + // InvalidPath contains a full nested error, and my time has been wasted too + // much with this already + store: "unknown", + source: format!("InvalidPath: {}", msg.source).into(), + }, + ObjectStoreErrorInnerProto::JoinError(msg) => object_store::Error::Generic { + // tokio::task::JoinError does not allow to be built + store: "unknown", + source: format!("JoinError: {}", msg.source).into(), + }, + ObjectStoreErrorInnerProto::NotSupported(msg) => object_store::Error::NotSupported { + source: msg.source.clone().into(), + }, + ObjectStoreErrorInnerProto::AlreadyExists(msg) => object_store::Error::AlreadyExists { + path: msg.path.clone(), + source: msg.source.clone().into(), + }, + ObjectStoreErrorInnerProto::Precondition(msg) => object_store::Error::Precondition { + path: msg.path.clone(), + source: msg.source.clone().into(), + }, + ObjectStoreErrorInnerProto::NotModified(msg) => object_store::Error::NotModified { + path: msg.path.clone(), + source: msg.source.clone().into(), + }, + ObjectStoreErrorInnerProto::NotImplemented(_) => object_store::Error::NotImplemented, + ObjectStoreErrorInnerProto::PermissionDenied(msg) => { + object_store::Error::PermissionDenied { + path: msg.path.clone(), + source: msg.source.clone().into(), + } + } + ObjectStoreErrorInnerProto::Unauthenticated(msg) => { + object_store::Error::Unauthenticated { + path: msg.path.clone(), + source: msg.source.clone().into(), + } + } + ObjectStoreErrorInnerProto::UnknownConfigurationKey(msg) => { + object_store::Error::UnknownConfigurationKey { + key: msg.key.clone(), + store: parse_store(&msg.store), + } + } + } + } +} + +fn 
parse_store(store: &str) -> &'static str { + // some appearances while looking at + // https://github.com/search?q=repo%3Aapache%2Farrow-rs-object-store%20store%3A%20%22&type=code + match store { + "GCS" => "GCS", + "MicrosoftAzure" => "MicrosoftAzure", + "S3" => "S3", + "Config" => "Config", + "ChunkedStore" => "ChunkedStore", + "LineDelimiter" => "LineDelimiter", + "HTTP client" => "HTTP client", + "HTTP" => "HTTP", + "URL" => "URL", + "InMemory" => "InMemory", + "ObjectStoreRegistry" => "ObjectStoreRegistry", + "Parts" => "Parts", + "LocalFileSystem" => "LocalFileSystem", + _ => "Unknown", + } +} + +#[cfg(test)] +mod tests { + use super::*; + use object_store::Error as ObjectStoreError; + use prost::Message; + use std::io::ErrorKind; + + #[test] + fn test_object_store_error_roundtrip() { + let test_cases = vec![ + // Use known store names that will be preserved + ObjectStoreError::Generic { + store: "S3", + source: Box::new(std::io::Error::new(ErrorKind::Other, "generic error")), + }, + ObjectStoreError::NotFound { + path: "test/path".to_string(), + source: Box::new(std::io::Error::new(ErrorKind::NotFound, "not found")), + }, + ObjectStoreError::AlreadyExists { + path: "existing/path".to_string(), + source: Box::new(std::io::Error::new(ErrorKind::AlreadyExists, "exists")), + }, + ObjectStoreError::Precondition { + path: "precondition/path".to_string(), + source: Box::new(std::io::Error::new(ErrorKind::Other, "precondition failed")), + }, + ObjectStoreError::NotSupported { + source: Box::new(std::io::Error::new(ErrorKind::Unsupported, "not supported")), + }, + ObjectStoreError::NotModified { + path: "not/modified".to_string(), + source: Box::new(std::io::Error::new(ErrorKind::Other, "not modified")), + }, + ObjectStoreError::NotImplemented, + ObjectStoreError::PermissionDenied { + path: "denied/path".to_string(), + source: Box::new(std::io::Error::new( + ErrorKind::PermissionDenied, + "permission denied", + )), + }, + ObjectStoreError::Unauthenticated { + path: 
"auth/path".to_string(), + source: Box::new(std::io::Error::new(ErrorKind::Other, "unauthenticated")), + }, + ObjectStoreError::UnknownConfigurationKey { + key: "unknown_key".to_string(), + store: "S3", + }, + ]; + + for original_error in test_cases { + let proto = ObjectStoreErrorProto::from_object_store_error(&original_error); + let proto = ObjectStoreErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_object_store_error(); + + assert_eq!(original_error.to_string(), recovered_error.to_string()); + } + } + + #[test] + fn test_malformed_protobuf_message() { + let malformed_proto = ObjectStoreErrorProto { inner: None }; + let recovered_error = malformed_proto.to_object_store_error(); + assert!(matches!(recovered_error, ObjectStoreError::Generic { .. })); + } +} diff --git a/src/errors/parquet_error.rs b/src/errors/parquet_error.rs new file mode 100644 index 0000000..991e68e --- /dev/null +++ b/src/errors/parquet_error.rs @@ -0,0 +1,129 @@ +use datafusion::parquet::errors::ParquetError; + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetErrorProto { + #[prost(oneof = "ParquetErrorInnerProto", tags = "1,2,3,4,5,6,7")] + pub inner: Option, +} + +#[derive(Clone, PartialEq, prost::Oneof)] +pub enum ParquetErrorInnerProto { + #[prost(message, tag = "1")] + General(String), + #[prost(message, tag = "2")] + NYI(String), + #[prost(message, tag = "3")] + EOF(String), + #[prost(message, tag = "4")] + ArrowError(String), + #[prost(message, tag = "5")] + IndexOutOfBound(IndexOutOfBoundProto), + #[prost(message, tag = "6")] + External(String), + #[prost(uint64, tag = "7")] + NeedMoreData(u64), +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct IndexOutOfBoundProto { + #[prost(uint64, tag = "1")] + a: u64, + #[prost(uint64, tag = "2")] + b: u64, +} + +impl ParquetErrorProto { + pub fn from_parquet_error(err: &ParquetError) -> Self { + match err { + ParquetError::General(msg) => ParquetErrorProto { + inner: 
Some(ParquetErrorInnerProto::General(msg.to_string())), + }, + ParquetError::NYI(msg) => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::NYI(msg.to_string())), + }, + ParquetError::EOF(msg) => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::EOF(msg.to_string())), + }, + ParquetError::ArrowError(msg) => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::ArrowError(msg.to_string())), + }, + ParquetError::IndexOutOfBound(a, b) => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::IndexOutOfBound( + IndexOutOfBoundProto { + a: *a as u64, + b: *b as u64, + }, + )), + }, + ParquetError::External(err) => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::External(err.to_string())), + }, + ParquetError::NeedMoreData(a) => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::NeedMoreData(*a as u64)), + }, + _ => ParquetErrorProto { + inner: Some(ParquetErrorInnerProto::General( + "ParquetError could not be serialized into protobuf".to_string(), + )), + }, + } + } + + pub fn to_parquet_error(&self) -> ParquetError { + let Some(ref inner) = self.inner else { + return ParquetError::External(Box::from("Malformed protobuf message".to_string())); + }; + + match inner { + ParquetErrorInnerProto::General(msg) => ParquetError::General(msg.to_string()), + ParquetErrorInnerProto::NYI(msg) => ParquetError::NYI(msg.to_string()), + ParquetErrorInnerProto::EOF(msg) => ParquetError::EOF(msg.to_string()), + ParquetErrorInnerProto::ArrowError(msg) => ParquetError::ArrowError(msg.to_string()), + ParquetErrorInnerProto::IndexOutOfBound(IndexOutOfBoundProto { a, b }) => { + ParquetError::IndexOutOfBound(*a as usize, *b as usize) + } + ParquetErrorInnerProto::External(msg) => { + ParquetError::External(Box::from(msg.to_string())) + } + ParquetErrorInnerProto::NeedMoreData(n) => ParquetError::NeedMoreData(*n as usize), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::parquet::errors::ParquetError; + use prost::Message; + + 
#[test] + fn test_parquet_error_roundtrip() { + let test_cases = vec![ + ParquetError::General("general error".to_string()), + ParquetError::NYI("not yet implemented".to_string()), + ParquetError::EOF("end of file".to_string()), + ParquetError::ArrowError("arrow error".to_string()), + ParquetError::IndexOutOfBound(42, 100), + ParquetError::External(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "external error", + ))), + ParquetError::NeedMoreData(1024), + ]; + + for original_error in test_cases { + let proto = ParquetErrorProto::from_parquet_error(&original_error); + let proto = ParquetErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_parquet_error(); + + assert_eq!(original_error.to_string(), recovered_error.to_string()); + } + } + + #[test] + fn test_malformed_protobuf_message() { + let malformed_proto = ParquetErrorProto { inner: None }; + let recovered_error = malformed_proto.to_parquet_error(); + assert!(matches!(recovered_error, ParquetError::External(_))); + } +} diff --git a/src/errors/parser_error.rs b/src/errors/parser_error.rs new file mode 100644 index 0000000..b42a911 --- /dev/null +++ b/src/errors/parser_error.rs @@ -0,0 +1,78 @@ +use datafusion::sql::sqlparser::parser::ParserError; + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParserErrorProto { + #[prost(oneof = "ParserErrorInnerProto", tags = "1,2,3")] + pub inner: Option, +} + +#[derive(Clone, PartialEq, prost::Oneof)] +pub enum ParserErrorInnerProto { + #[prost(string, tag = "1")] + TokenizerError(String), + #[prost(string, tag = "2")] + ParserError(String), + #[prost(bool, tag = "3")] + RecursionLimitExceeded(bool), +} + +impl ParserErrorProto { + pub fn from_parser_error(err: &ParserError) -> Self { + match err { + ParserError::TokenizerError(msg) => ParserErrorProto { + inner: Some(ParserErrorInnerProto::TokenizerError(msg.to_string())), + }, + ParserError::ParserError(msg) => ParserErrorProto { + inner: 
Some(ParserErrorInnerProto::ParserError(msg.to_string())), + }, + ParserError::RecursionLimitExceeded => ParserErrorProto { + inner: Some(ParserErrorInnerProto::RecursionLimitExceeded(true)), + }, + } + } + + pub fn to_parser_error(&self) -> ParserError { + let Some(ref inner) = self.inner else { + return ParserError::ParserError("Malformed protobuf message".to_string()); + }; + + match inner { + ParserErrorInnerProto::TokenizerError(msg) => { + ParserError::TokenizerError(msg.to_string()) + } + ParserErrorInnerProto::ParserError(msg) => ParserError::ParserError(msg.to_string()), + ParserErrorInnerProto::RecursionLimitExceeded(_) => ParserError::RecursionLimitExceeded, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::sql::sqlparser::parser::ParserError; + use prost::Message; + + #[test] + fn test_parser_error_roundtrip() { + let test_cases = vec![ + ParserError::ParserError("syntax error".to_string()), + ParserError::TokenizerError("tokenizer error".to_string()), + ParserError::RecursionLimitExceeded, + ]; + + for original_error in test_cases { + let proto = ParserErrorProto::from_parser_error(&original_error); + let proto = ParserErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_error = proto.to_parser_error(); + + assert_eq!(original_error.to_string(), recovered_error.to_string()); + } + } + + #[test] + fn test_malformed_protobuf_message() { + let malformed_proto = ParserErrorProto { inner: None }; + let recovered_error = malformed_proto.to_parser_error(); + assert!(matches!(recovered_error, ParserError::ParserError(_))); + } +} diff --git a/src/errors/schema_error.rs b/src/errors/schema_error.rs new file mode 100644 index 0000000..925f97f --- /dev/null +++ b/src/errors/schema_error.rs @@ -0,0 +1,362 @@ +use datafusion::common::{Column, SchemaError, TableReference}; + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SchemaErrorProto { + #[prost(string, optional, tag = "1")] + pub backtrace: Option, + 
#[prost(oneof = "SchemaErrorInnerProto", tags = "2,3,4,5")] + pub inner: Option, +} + +#[derive(Clone, PartialEq, prost::Oneof)] +pub enum SchemaErrorInnerProto { + #[prost(message, tag = "2")] + AmbiguousReference(AmbiguousReferenceProto), + #[prost(message, tag = "3")] + DuplicateQualifiedField(DuplicateQualifiedFieldProto), + #[prost(message, tag = "4")] + DuplicateUnqualifiedField(DuplicateUnqualifiedFieldProto), + #[prost(message, tag = "5")] + FieldNotFound(FieldNotFoundProto), +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AmbiguousReferenceProto { + #[prost(message, tag = "1")] + field: Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DuplicateQualifiedFieldProto { + #[prost(message, tag = "1")] + qualifier: Option, + #[prost(string, tag = "2")] + name: String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DuplicateUnqualifiedFieldProto { + #[prost(string, tag = "1")] + name: String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FieldNotFoundProto { + #[prost(message, boxed, tag = "1")] + field: Option>, + #[prost(message, repeated, tag = "2")] + valid_fields: Vec, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnProto { + #[prost(message, tag = "1")] + pub relation: Option, + #[prost(string, tag = "2")] + pub name: String, + // No spans +} + +impl ColumnProto { + pub fn from_column(v: &Column) -> Self { + ColumnProto { + relation: v + .relation + .as_ref() + .map(TableReferenceProto::from_table_reference), + name: v.name.to_string(), + } + } + + pub fn to_column(&self) -> Column { + Column::new( + self.relation.as_ref().map(|v| v.to_table_reference()), + self.name.clone(), + ) + } +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TableReferenceProto { + #[prost(oneof = "TableReferenceInnerProto", tags = "1,2,3")] + pub inner: Option, +} + +#[derive(Clone, PartialEq, prost::Oneof)] +pub enum TableReferenceInnerProto { + #[prost(message, tag = "1")] + 
Bare(TableReferenceBareProto), + #[prost(message, tag = "2")] + Partial(TableReferencePartialProto), + #[prost(message, tag = "3")] + Full(TableReferenceFullProto), +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TableReferenceBareProto { + #[prost(string, tag = "1")] + pub table: String, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TableReferencePartialProto { + #[prost(string, tag = "1")] + pub schema: String, + #[prost(string, tag = "2")] + pub table: String, +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TableReferenceFullProto { + #[prost(string, tag = "1")] + pub catalog: String, + #[prost(string, tag = "2")] + pub schema: String, + #[prost(string, tag = "3")] + pub table: String, +} + +impl TableReferenceProto { + pub fn from_table_reference(v: &TableReference) -> Self { + match v { + TableReference::Bare { table } => TableReferenceProto { + inner: Some(TableReferenceInnerProto::Bare(TableReferenceBareProto { + table: table.to_string(), + })), + }, + TableReference::Partial { schema, table } => TableReferenceProto { + inner: Some(TableReferenceInnerProto::Partial( + TableReferencePartialProto { + schema: schema.to_string(), + table: table.to_string(), + }, + )), + }, + TableReference::Full { + catalog, + schema, + table, + } => TableReferenceProto { + inner: Some(TableReferenceInnerProto::Full(TableReferenceFullProto { + catalog: catalog.to_string(), + schema: schema.to_string(), + table: table.to_string(), + })), + }, + } + } + + pub fn to_table_reference(&self) -> TableReference { + let Some(ref inner) = self.inner else { + return TableReference::bare(""); + }; + + match inner { + TableReferenceInnerProto::Bare(msg) => TableReference::Bare { + table: msg.table.clone().into(), + }, + TableReferenceInnerProto::Partial(msg) => TableReference::Partial { + schema: msg.schema.clone().into(), + table: msg.table.clone().into(), + }, + TableReferenceInnerProto::Full(msg) => TableReference::Full { + catalog: 
msg.catalog.clone().into(), + schema: msg.schema.clone().into(), + table: msg.table.clone().into(), + }, + } + } +} + +impl SchemaErrorProto { + pub fn from_schema_error(err: &SchemaError, backtrace: Option<&String>) -> Self { + match err { + SchemaError::AmbiguousReference { ref field } => SchemaErrorProto { + inner: Some(SchemaErrorInnerProto::AmbiguousReference( + AmbiguousReferenceProto { + field: Some(ColumnProto::from_column(field)), + }, + )), + backtrace: backtrace.cloned(), + }, + SchemaError::DuplicateQualifiedField { qualifier, name } => SchemaErrorProto { + inner: Some(SchemaErrorInnerProto::DuplicateQualifiedField( + DuplicateQualifiedFieldProto { + qualifier: Some(TableReferenceProto::from_table_reference(qualifier)), + name: name.to_string(), + }, + )), + backtrace: backtrace.cloned(), + }, + SchemaError::DuplicateUnqualifiedField { name } => SchemaErrorProto { + inner: Some(SchemaErrorInnerProto::DuplicateUnqualifiedField( + DuplicateUnqualifiedFieldProto { + name: name.to_string(), + }, + )), + backtrace: backtrace.cloned(), + }, + SchemaError::FieldNotFound { + field, + valid_fields, + } => SchemaErrorProto { + inner: Some(SchemaErrorInnerProto::FieldNotFound(FieldNotFoundProto { + field: Some(Box::new(ColumnProto::from_column(&field))), + valid_fields: valid_fields.iter().map(ColumnProto::from_column).collect(), + })), + backtrace: backtrace.cloned(), + }, + } + } + + pub fn to_schema_error(&self) -> (SchemaError, Option) { + let Some(ref inner) = self.inner else { + // Found no better default. 
+ return ( + SchemaError::FieldNotFound { + field: Box::new(Column::new_unqualified("".to_string())), + valid_fields: vec![], + }, + None, + ); + }; + + let err = match inner { + SchemaErrorInnerProto::AmbiguousReference(err) => SchemaError::AmbiguousReference { + field: err + .field + .as_ref() + .map(|v| v.to_column()) + .unwrap_or(Column::new_unqualified("".to_string())), + }, + SchemaErrorInnerProto::DuplicateQualifiedField(err) => { + SchemaError::DuplicateQualifiedField { + qualifier: Box::new( + err.qualifier + .as_ref() + .map(|v| v.to_table_reference()) + .unwrap_or(TableReference::Bare { table: "".into() }), + ), + name: err.name.clone(), + } + } + SchemaErrorInnerProto::DuplicateUnqualifiedField(err) => { + SchemaError::DuplicateUnqualifiedField { + name: err.name.clone(), + } + } + SchemaErrorInnerProto::FieldNotFound(err) => SchemaError::FieldNotFound { + field: Box::new( + err.field + .as_ref() + .map(|v| v.to_column()) + .unwrap_or(Column::new_unqualified("".to_string())), + ), + valid_fields: err.valid_fields.iter().map(|v| v.to_column()).collect(), + }, + }; + (err, self.backtrace.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::{Column, SchemaError, TableReference}; + use prost::Message; + + #[test] + fn test_schema_error_roundtrip() { + let test_cases = vec![ + SchemaError::AmbiguousReference { + field: Column::new_unqualified("test_field"), + }, + SchemaError::DuplicateQualifiedField { + qualifier: Box::new(TableReference::bare("table")), + name: "field".to_string(), + }, + SchemaError::DuplicateUnqualifiedField { + name: "field".to_string(), + }, + SchemaError::FieldNotFound { + field: Box::new(Column::new( + Some(TableReference::bare("table")), + "missing_field", + )), + valid_fields: vec![ + Column::new_unqualified("field1"), + Column::new_unqualified("field2"), + ], + }, + ]; + + for original_error in test_cases { + let proto = SchemaErrorProto::from_schema_error( + &original_error, + Some(&"test 
backtrace".to_string()), + ); + let proto = SchemaErrorProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let (recovered_error, recovered_backtrace) = proto.to_schema_error(); + + assert_eq!(original_error.to_string(), recovered_error.to_string()); + assert_eq!(recovered_backtrace, Some("test backtrace".to_string())); + + let proto_no_backtrace = SchemaErrorProto::from_schema_error(&original_error, None); + let proto_no_backtrace = + SchemaErrorProto::decode(proto_no_backtrace.encode_to_vec().as_ref()).unwrap(); + let (recovered_error_no_backtrace, recovered_backtrace_no_backtrace) = + proto_no_backtrace.to_schema_error(); + + assert_eq!( + original_error.to_string(), + recovered_error_no_backtrace.to_string() + ); + assert_eq!(recovered_backtrace_no_backtrace, None); + } + } + + #[test] + fn test_malformed_protobuf_message() { + let malformed_proto = SchemaErrorProto { + inner: None, + backtrace: None, + }; + let (recovered_error, _) = malformed_proto.to_schema_error(); + assert!(matches!(recovered_error, SchemaError::FieldNotFound { .. 
})); + } + + #[test] + fn test_table_reference_roundtrip() { + let test_cases = vec![ + TableReference::bare("table"), + TableReference::partial("schema", "table"), + TableReference::full("catalog", "schema", "table"), + ]; + + for original_ref in test_cases { + let proto = TableReferenceProto::from_table_reference(&original_ref); + let proto = TableReferenceProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_ref = proto.to_table_reference(); + + assert_eq!(original_ref.to_string(), recovered_ref.to_string()); + } + } + + #[test] + fn test_column_roundtrip() { + let test_cases = vec![ + Column::new_unqualified("test_field"), + Column::new(Some(TableReference::bare("table")), "field"), + Column::new(Some(TableReference::partial("schema", "table")), "field"), + ]; + + for original_column in test_cases { + let proto = ColumnProto::from_column(&original_column); + let proto = ColumnProto::decode(proto.encode_to_vec().as_ref()).unwrap(); + let recovered_column = proto.to_column(); + + assert_eq!(original_column.name, recovered_column.name); + assert_eq!( + original_column.relation.is_some(), + recovered_column.relation.is_some() + ); + } + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index c904c62..4771373 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,4 +1,5 @@ use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; use crate::plan::ArrowFlightReadExecProtoCodec; use crate::stage_delegation::{ActorContext, StageContext}; @@ -140,22 +141,24 @@ impl ArrowFlightEndpoint { let stream_partitioner = self .partitioner_registry .get_or_create_stream_partitioner(stage_id, actor_idx, plan, partitioning) - .map_err(|err| { - Status::internal(format!("Could not create stream partitioner: {err}")) - })?; + .map_err(|err| datafusion_error_to_tonic_status(&err))?; let stream 
= stream_partitioner .execute(caller_actor_idx, state.task_ctx()) - .map_err(|err| Status::internal(format!("Cannot get stream partition: {err}")))?; + .map_err(|err| datafusion_error_to_tonic_status(&err))?; - // TODO: error propagation let flight_data_stream = FlightDataEncoderBuilder::new() .with_schema(stream_partitioner.schema()) - .build(stream.map_err(|err| FlightError::ExternalError(Box::new(err)))); - - Ok(Response::new(Box::pin(flight_data_stream.map_err(|err| { - Status::internal(format!("Error during flight stream: {err}")) - })))) + .build(stream.map_err(|err| { + FlightError::Tonic(Box::new(datafusion_error_to_tonic_status(&err))) + })); + + Ok(Response::new(Box::pin(flight_data_stream.map_err( + |err| match err { + FlightError::Tonic(status) => *status, + _ => Status::internal(format!("Error during flight stream: {err}")), + }, + )))) } } diff --git a/src/lib.rs b/src/lib.rs index 3d544ce..d48ecb8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ mod channel_manager; mod composed_extension_codec; +mod errors; mod flight_service; mod plan; mod stage_delegation; diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index d5b174f..dce2472 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,5 +1,6 @@ use crate::channel_manager::{ArrowFlightChannel, ChannelManager}; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::{DoGet, DoPut}; use crate::plan::arrow_flight_read_proto::ArrowFlightReadExecProtoCodec; use crate::stage_delegation::{ActorContext, StageContext, StageDelegation}; @@ -189,14 +190,18 @@ impl ExecutionPlan for ArrowFlightReadExec { let stream = client .do_get(ticket.into_request()) .await - .map_err(|err| DataFusionError::External(Box::new(err)))? + .map_err(|err| tonic_status_to_datafusion_error(&err).unwrap_or_else(|| { + DataFusionError::External(Box::new(err)) + }))? 
.into_inner() .map_err(|err| FlightError::Tonic(Box::new(err))); Ok(FlightRecordBatchStream::new_from_flight_data(stream) - // TODO: propagate the error from the service to here, probably serializing it - // somehow. - .map_err(|err| DataFusionError::External(Box::new(err)))) + .map_err(|err| match err { + FlightError::Tonic(status) => tonic_status_to_datafusion_error(&status) + .unwrap_or_else(|| DataFusionError::External(Box::new(status))), + err => DataFusionError::External(Box::new(err)) + })) }.try_flatten_stream(); Ok(Box::pin(RecordBatchStreamAdapter::new( diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs new file mode 100644 index 0000000..f9ae75f --- /dev/null +++ b/tests/error_propagation.rs @@ -0,0 +1,172 @@ +#[allow(dead_code)] +mod common; + +#[cfg(test)] +mod tests { + use crate::common::localhost::start_localhost_context; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::error::DataFusionError; + use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionStateBuilder, TaskContext, + }; + use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; + use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::physical_plan::{ + execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + }; + use datafusion_distributed::{ArrowFlightReadExec, SessionBuilder}; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use datafusion_proto::protobuf::proto_error; + use futures::{stream, TryStreamExt}; + use prost::Message; + use std::any::Any; + use std::error::Error; + use std::fmt::Formatter; + use std::sync::Arc; + + #[tokio::test] + async fn test_error_propagation() -> Result<(), Box> { + #[derive(Clone)] + struct CustomSessionBuilder; + impl SessionBuilder for CustomSessionBuilder { + fn on_new_session(&self, mut builder: SessionStateBuilder) -> 
SessionStateBuilder { + let codec: Arc = Arc::new(ErrorExecCodec); + let config = builder.config().get_or_insert_default(); + config.set_extension(Arc::new(codec)); + builder + } + } + let (ctx, _guard) = + start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await; + + let codec: Arc = Arc::new(ErrorExecCodec); + ctx.state_ref() + .write() + .config_mut() + .set_extension(Arc::new(codec)); + + let mut plan: Arc = Arc::new(ErrorExec::new("something failed")); + + for size in [1, 2, 3] { + plan = Arc::new(ArrowFlightReadExec::new( + plan, + Partitioning::RoundRobinBatch(size), + )); + } + + let stream = execute_stream(plan, ctx.task_ctx())?; + + let Err(err) = stream.try_collect::>().await else { + panic!("Should have failed") + }; + assert_eq!( + DataFusionError::Execution("something failed".to_string()).to_string(), + err.to_string() + ); + + Ok(()) + } + + #[derive(Debug)] + pub struct ErrorExec { + msg: String, + plan_properties: PlanProperties, + } + + impl ErrorExec { + fn new(msg: &str) -> Self { + let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]); + Self { + msg: msg.to_string(), + plan_properties: PlanProperties::new( + EquivalenceProperties::new(Arc::new(schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), + } + } + } + + impl DisplayAs for ErrorExec { + fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "ErrorExec") + } + } + + impl ExecutionPlan for ErrorExec { + fn name(&self) -> &str { + "ErrorExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + _: usize, + _: Arc, + ) -> datafusion::common::Result { + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + 
stream::iter(vec![Err(DataFusionError::Execution(self.msg.clone()))]), + ))) + } + } + + #[derive(Debug)] + struct ErrorExecCodec; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct ErrorExecProto { + #[prost(string, tag = "1")] + msg: String, + } + + impl PhysicalExtensionCodec for ErrorExecCodec { + fn try_decode( + &self, + buf: &[u8], + _: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + let node = ErrorExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?; + Ok(Arc::new(ErrorExec::new(&node.msg))) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + let Some(plan) = node.as_any().downcast_ref::() else { + return Err(proto_error(format!( + "Expected plan to be of type ErrorExec, but was {}", + node.name() + ))); + }; + ErrorExecProto { + msg: plan.msg.clone(), + } + .encode(buf) + .map_err(|err| proto_error(format!("{err}"))) + } + } +}