diff --git a/src/config_extension_ext.rs b/src/config_extension_ext.rs new file mode 100644 index 0000000..41f5141 --- /dev/null +++ b/src/config_extension_ext.rs @@ -0,0 +1,466 @@ +use datafusion::common::{internal_datafusion_err, DataFusionError}; +use datafusion::config::ConfigExtension; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use delegate::delegate; +use http::{HeaderMap, HeaderName}; +use std::error::Error; +use std::str::FromStr; +use std::sync::Arc; + +const FLIGHT_METADATA_CONFIG_PREFIX: &str = "x-datafusion-distributed-config-"; + +/// Extension trait for `SessionConfig` to add support for propagating [ConfigExtension]s across +/// network calls. +pub trait ConfigExtensionExt { + /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will + /// be serialized using gRPC metadata and sent across tasks. Users are expected to call this + /// method with their own extensions to be able to access them in any place in the + /// plan. + /// + /// This method also adds the provided [ConfigExtension] to the current session option + /// extensions, the same as calling [SessionConfig::with_option_extension]. + /// + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::SessionState; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut config = SessionConfig::new(); + /// let mut opt = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// config.add_distributed_option_extension(opt).unwrap(); + /// + /// struct MyCustomSessionBuilder; + /// + /// #[async_trait] + /// impl SessionBuilder for MyCustomSessionBuilder { + /// async fn session_state(&self, mut state: SessionState) -> Result { + /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will + /// // know how to deserialize the CustomExtension from the gRPC metadata. + /// state.retrieve_distributed_option_extension::()?; + /// Ok(state) + /// } + /// } + /// ``` + fn add_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError>; + + /// Gets the specified [ConfigExtension] from the distributed context and adds it to + /// the [SessionConfig::options] extensions. The function will build a new [ConfigExtension] + /// out of the Arrow Flight gRPC metadata present in the [SessionConfig] and will propagate it + /// to the extension options. + /// Example: + /// + /// ```rust + /// # use async_trait::async_trait; + /// # use datafusion::common::{extensions_options, DataFusionError}; + /// # use datafusion::config::ConfigExtension; + /// # use datafusion::execution::SessionState; + /// # use datafusion::prelude::SessionConfig; + /// # use datafusion_distributed::{ConfigExtensionExt, SessionBuilder}; + /// + /// extensions_options! { + /// pub struct CustomExtension { + /// pub foo: String, default = "".to_string() + /// pub bar: usize, default = 0 + /// pub baz: bool, default = false + /// } + /// } + /// + /// impl ConfigExtension for CustomExtension { + /// const PREFIX: &'static str = "custom"; + /// } + /// + /// let mut config = SessionConfig::new(); + /// let mut opt = CustomExtension::default(); + /// // Now, the CustomExtension will be able to cross network boundaries. Upon making an Arrow + /// // Flight request, it will be sent through gRPC metadata. + /// config.add_distributed_option_extension(opt).unwrap(); + /// + /// struct MyCustomSessionBuilder; + /// + /// #[async_trait] + /// impl SessionBuilder for MyCustomSessionBuilder { + /// async fn session_state(&self, mut state: SessionState) -> Result { + /// // while providing this MyCustomSessionBuilder to an Arrow Flight endpoint, it will + /// // know how to deserialize the CustomExtension from the gRPC metadata. + /// state.retrieve_distributed_option_extension::()?; + /// Ok(state) + /// } + /// } + /// ``` + fn retrieve_distributed_option_extension( + &mut self, + ) -> Result<(), DataFusionError>; +} + +impl ConfigExtensionExt for SessionConfig { + fn add_distributed_option_extension( + &mut self, + t: T, + ) -> Result<(), DataFusionError> { + fn parse_err(err: impl Error) -> DataFusionError { + DataFusionError::Internal(format!("Failed to add config extension: {err}")) + } + let mut meta = HeaderMap::new(); + + for entry in t.entries() { + if let Some(value) = entry.value { + meta.insert( + HeaderName::from_str(&format!( + "{}{}.{}", + FLIGHT_METADATA_CONFIG_PREFIX, + T::PREFIX, + entry.key + )) + .map_err(parse_err)?, + value.parse().map_err(parse_err)?, + ); + } + } + let flight_metadata = ContextGrpcMetadata(meta); + match self.get_extension::() { + None => self.set_extension(Arc::new(flight_metadata)), + Some(prev) => { + let prev = prev.as_ref().clone(); + self.set_extension(Arc::new(prev.merge(flight_metadata))) + } + } + self.options_mut().extensions.insert(t); + Ok(()) + } + + fn retrieve_distributed_option_extension( + &mut self, + ) -> Result<(), DataFusionError> { + let Some(flight_metadata) = self.get_extension::() else { + return Ok(()); + }; + + let mut result = T::default(); + let mut found_some = false; + for (k, v) in flight_metadata.0.iter() { + let key = k.as_str().trim_start_matches(FLIGHT_METADATA_CONFIG_PREFIX); + let prefix = format!("{}.", T::PREFIX); + if key.starts_with(&prefix) { + found_some = true; + result.set( + &key.trim_start_matches(&prefix), + v.to_str().map_err(|err| { + internal_datafusion_err!("Cannot parse header value: {err}") + })?, + )?; + } + } + if !found_some { + return Ok(()); + } + self.options_mut().extensions.insert(result); + Ok(()) + } +} + +impl ConfigExtensionExt for SessionStateBuilder { + delegate! { + to self.config().get_or_insert_default() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +impl ConfigExtensionExt for SessionState { + delegate! { + to self.config_mut() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +impl ConfigExtensionExt for SessionContext { + delegate! { + to self.state_ref().write().config_mut() { + fn add_distributed_option_extension(&mut self, t: T) -> Result<(), DataFusionError>; + fn retrieve_distributed_option_extension(&mut self) -> Result<(), DataFusionError>; + } + } +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct ContextGrpcMetadata(pub HeaderMap); + +impl ContextGrpcMetadata { + pub(crate) fn from_headers(metadata: HeaderMap) -> Self { + let mut new = HeaderMap::new(); + for (k, v) in metadata.into_iter() { + let Some(k) = k else { continue }; + if k.as_str().starts_with(FLIGHT_METADATA_CONFIG_PREFIX) { + new.insert(k, v); + } + } + Self(new) + } + + fn merge(mut self, other: Self) -> Self { + for (k, v) in other.0.into_iter() { + let Some(k) = k else { continue }; + self.0.insert(k, v); + } + self + } +} + +#[cfg(test)] +mod tests { + use crate::config_extension_ext::ContextGrpcMetadata; + use crate::ConfigExtensionExt; + use datafusion::common::extensions_options; + use datafusion::config::ConfigExtension; + use datafusion::prelude::SessionConfig; + use http::{HeaderMap, HeaderName, HeaderValue}; + use std::str::FromStr; + + #[test] + fn test_propagation() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut opt = CustomExtension::default(); + opt.foo = "foo".to_string(); + opt.bar = 1; + opt.baz = true; + + config.add_distributed_option_extension(opt)?; + + let mut new_config = SessionConfig::new(); + new_config.set_extension(config.get_extension::().unwrap()); + new_config.retrieve_distributed_option_extension::()?; + + let opt = get_ext::(&config); + let new_opt = get_ext::(&new_config); + + assert_eq!(new_opt.foo, opt.foo); + assert_eq!(new_opt.bar, opt.bar); + assert_eq!(new_opt.baz, opt.baz); + + Ok(()) + } + + #[test] + fn test_add_extension_with_empty_values() -> Result<(), Box> { + let mut config = SessionConfig::new(); + let opt = CustomExtension::default(); + + config.add_distributed_option_extension(opt)?; + + let flight_metadata = config.get_extension::(); + assert!(flight_metadata.is_some()); + + let metadata = &flight_metadata.unwrap().0; + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.baz")); + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "0"); + assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false"); + + Ok(()) + } + + #[test] + fn test_new_extension_overwrites_previous() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut opt1 = CustomExtension::default(); + opt1.foo = "first".to_string(); + config.add_distributed_option_extension(opt1)?; + + let mut opt2 = CustomExtension::default(); + opt2.bar = 42; + config.add_distributed_option_extension(opt2)?; + + let flight_metadata = config.get_extension::().unwrap(); + let metadata = &flight_metadata.0; + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + assert_eq!(get("x-datafusion-distributed-config-custom.foo"), ""); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "42"); + assert_eq!(get("x-datafusion-distributed-config-custom.baz"), "false"); + + Ok(()) + } + + #[test] + fn test_propagate_no_metadata() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + config.retrieve_distributed_option_extension::()?; + + let extension = config.options().extensions.get::(); + assert!(extension.is_none()); + + Ok(()) + } + + #[test] + fn test_propagate_no_matching_prefix() -> Result<(), Box> { + let mut config = SessionConfig::new(); + let mut header_map = HeaderMap::new(); + header_map.insert( + HeaderName::from_str("x-datafusion-distributed-other.setting").unwrap(), + HeaderValue::from_str("value").unwrap(), + ); + + let flight_metadata = ContextGrpcMetadata::from_headers(header_map); + config.set_extension(std::sync::Arc::new(flight_metadata)); + config.retrieve_distributed_option_extension::()?; + + let extension = config.options().extensions.get::(); + assert!(extension.is_none()); + + Ok(()) + } + + #[test] + fn test_multiple_extensions_different_prefixes() -> Result<(), Box> { + let mut config = SessionConfig::new(); + + let mut custom_opt = CustomExtension::default(); + custom_opt.foo = "custom_value".to_string(); + custom_opt.bar = 123; + + let mut another_opt = AnotherExtension::default(); + another_opt.setting1 = "other".to_string(); + another_opt.setting2 = 456; + + config.add_distributed_option_extension(custom_opt)?; + config.add_distributed_option_extension(another_opt)?; + + let flight_metadata = config.get_extension::().unwrap(); + let metadata = &flight_metadata.0; + + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.foo")); + assert!(metadata.contains_key("x-datafusion-distributed-config-custom.bar")); + assert!(metadata.contains_key("x-datafusion-distributed-config-another.setting1")); + assert!(metadata.contains_key("x-datafusion-distributed-config-another.setting2")); + + let get = |key: &str| metadata.get(key).unwrap().to_str().unwrap(); + + assert_eq!( + get("x-datafusion-distributed-config-custom.foo"), + "custom_value" + ); + assert_eq!(get("x-datafusion-distributed-config-custom.bar"), "123"); + assert_eq!( + get("x-datafusion-distributed-config-another.setting1"), + "other" + ); + assert_eq!( + get("x-datafusion-distributed-config-another.setting2"), + "456" + ); + + let mut new_config = SessionConfig::new(); + new_config.set_extension(flight_metadata); + new_config.retrieve_distributed_option_extension::()?; + new_config.retrieve_distributed_option_extension::()?; + + let propagated_custom = get_ext::(&new_config); + let propagated_another = get_ext::(&new_config); + + assert_eq!(propagated_custom.foo, "custom_value"); + assert_eq!(propagated_custom.bar, 123); + assert_eq!(propagated_another.setting1, "other"); + assert_eq!(propagated_another.setting2, 456); + + Ok(()) + } + + #[test] + fn test_invalid_header_name() { + let mut config = SessionConfig::new(); + let extension = InvalidExtension::default(); + + let result = config.add_distributed_option_extension(extension); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_header_value() { + let mut config = SessionConfig::new(); + let extension = InvalidValueExtension::default(); + + let result = config.add_distributed_option_extension(extension); + assert!(result.is_err()); + } + + extensions_options! { + pub struct CustomExtension { + pub foo: String, default = "".to_string() + pub bar: usize, default = 0 + pub baz: bool, default = false + } + } + + impl ConfigExtension for CustomExtension { + const PREFIX: &'static str = "custom"; + } + + extensions_options! { + pub struct AnotherExtension { + pub setting1: String, default = "default1".to_string() + pub setting2: usize, default = 42 + } + } + + impl ConfigExtension for AnotherExtension { + const PREFIX: &'static str = "another"; + } + + extensions_options! { + pub struct InvalidExtension { + pub key_with_spaces: String, default = "value".to_string() + } + } + + impl ConfigExtension for InvalidExtension { + const PREFIX: &'static str = "invalid key with spaces"; + } + + extensions_options! { + pub struct InvalidValueExtension { + pub key: String, default = "\u{0000}invalid\u{0001}".to_string() + } + } + + impl ConfigExtension for InvalidValueExtension { + const PREFIX: &'static str = "invalid_value"; + } + + fn get_ext(cfg: &SessionConfig) -> &T { + cfg.options().extensions.get::().unwrap() + } +} diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 38b7d25..2c14bba 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -1,4 +1,6 @@ +use super::service::StageKey; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::datafusion_error_to_tonic_status; use crate::flight_service::service::ArrowFlightEndpoint; use crate::plan::{DistributedCodec, PartitionGroup}; @@ -10,13 +12,13 @@ use arrow_flight::flight_service_server::FlightService; use arrow_flight::Ticket; use datafusion::execution::{SessionState, SessionStateBuilder}; use datafusion::optimizer::OptimizerConfig; +use datafusion::prelude::SessionConfig; use futures::TryStreamExt; use prost::Message; use std::sync::Arc; +use tonic::metadata::MetadataMap; use tonic::{Request, Response, Status}; -use super::service::StageKey; - #[derive(Clone, PartialEq, ::prost::Message)] pub struct DoGet { /// The ExecutionStage that we are going to execute @@ -41,14 +43,15 @@ impl ArrowFlightEndpoint { &self, request: Request, ) -> Result::DoGetStream>, Status> { - let Ticket { ticket } = request.into_inner(); + let (metadata, _ext, ticket) = request.into_parts(); + let Ticket { ticket } = ticket; let doget = DoGet::decode(ticket).map_err(|err| { Status::invalid_argument(format!("Cannot decode DoGet message: {err}")) })?; let partition = doget.partition as usize; let task_number = doget.task_number as usize; - let (mut state, stage) = self.get_state_and_stage(doget).await?; + let (mut state, stage) = self.get_state_and_stage(doget, metadata).await?; // find out which partition group we are executing let task = stage @@ -87,6 +90,7 @@ impl ArrowFlightEndpoint { async fn get_state_and_stage( &self, doget: DoGet, + metadata: MetadataMap, ) -> Result<(SessionState, Arc), Status> { let key = doget .stage_key @@ -102,9 +106,16 @@ impl ArrowFlightEndpoint { .stage_proto .ok_or(Status::invalid_argument("DoGet is missing the stage proto"))?; + let mut config = SessionConfig::default(); + config.set_extension(Arc::new(ContextGrpcMetadata::from_headers( + metadata.into_headers(), + ))); + let state_builder = SessionStateBuilder::new() .with_runtime_env(Arc::clone(&self.runtime)) + .with_config(config) .with_default_features(); + let state_builder = self .session_builder .session_state_builder(state_builder) diff --git a/src/lib.rs b/src/lib.rs index 826bdf7..40515a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ mod channel_manager; mod common; mod composed_extension_codec; +mod config_extension_ext; mod errors; mod flight_service; mod physical_optimizer; @@ -15,6 +16,7 @@ mod user_provided_codec; pub mod test_utils; pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver}; +pub use config_extension_ext::ConfigExtensionExt; pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder}; pub use physical_optimizer::DistributedPhysicalOptimizerRule; pub use plan::ArrowFlightReadExec; diff --git a/src/plan/arrow_flight_read.rs b/src/plan/arrow_flight_read.rs index 62c5e76..36d7786 100644 --- a/src/plan/arrow_flight_read.rs +++ b/src/plan/arrow_flight_read.rs @@ -1,6 +1,7 @@ use super::combined::CombinedRecordBatchStream; use crate::channel_manager::ChannelManager; use crate::composed_extension_codec::ComposedPhysicalExtensionCodec; +use crate::config_extension_ext::ContextGrpcMetadata; use crate::errors::tonic_status_to_datafusion_error; use crate::flight_service::{DoGet, StageKey}; use crate::plan::DistributedCodec; @@ -19,10 +20,13 @@ use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; use futures::{future, TryFutureExt, TryStreamExt}; +use http::Extensions; use prost::Message; use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; +use tonic::metadata::MetadataMap; +use tonic::Request; use url::Url; /// This node has two variants. @@ -173,6 +177,10 @@ impl ExecutionPlan for ArrowFlightReadExec { this.stage_num ))?; + let flight_metadata = context + .session_config() + .get_extension::(); + let mut combined_codec = ComposedPhysicalExtensionCodec::default(); combined_codec.push(DistributedCodec {}); if let Some(ref user_codec) = get_user_codec(context.session_config()) { @@ -195,6 +203,10 @@ impl ExecutionPlan for ArrowFlightReadExec { let channel_manager_capture = channel_manager.clone(); let schema = schema.clone(); let query_id = query_id.clone(); + let flight_metadata = flight_metadata + .as_ref() + .map(|v| v.as_ref().clone()) + .unwrap_or_default(); let key = StageKey { query_id, stage_id: child_stage_num, @@ -218,8 +230,14 @@ impl ExecutionPlan for ArrowFlightReadExec { ticket: ticket_bytes, }; - stream_from_stage_task(ticket, &url, schema.clone(), &channel_manager_capture) - .await + stream_from_stage_task( + ticket, + flight_metadata, + &url, + schema.clone(), + &channel_manager_capture, + ) + .await } }); @@ -240,12 +258,19 @@ impl ExecutionPlan for ArrowFlightReadExec { async fn stream_from_stage_task( ticket: Ticket, + metadata: ContextGrpcMetadata, url: &Url, schema: SchemaRef, channel_manager: &ChannelManager, ) -> Result { let channel = channel_manager.get_channel_for_url(url).await?; + let ticket = Request::from_parts( + MetadataMap::from_headers(metadata.0), + Extensions::default(), + ticket, + ); + let mut client = FlightServiceClient::new(channel); let stream = client .do_get(ticket) diff --git a/tests/custom_config_extension.rs b/tests/custom_config_extension.rs new file mode 100644 index 0000000..d671d90 --- /dev/null +++ b/tests/custom_config_extension.rs @@ -0,0 +1,180 @@ +#[cfg(all(feature = "integration", test))] +mod tests { + use async_trait::async_trait; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{extensions_options, internal_err}; + use datafusion::config::ConfigExtension; + use datafusion::error::DataFusionError; + use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, TaskContext, + }; + use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; + use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; + use datafusion::physical_plan::stream::RecordBatchStreamAdapter; + use datafusion::physical_plan::{ + execute_stream, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + }; + use datafusion_distributed::test_utils::localhost::start_localhost_context; + use datafusion_distributed::{add_user_codec, ConfigExtensionExt}; + use datafusion_distributed::{ + ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder, + }; + use datafusion_proto::physical_plan::PhysicalExtensionCodec; + use futures::TryStreamExt; + use prost::Message; + use std::any::Any; + use std::fmt::Formatter; + use std::sync::Arc; + + #[tokio::test] + async fn custom_config_extension() -> Result<(), Box> { + #[derive(Clone)] + struct CustomSessionBuilder; + + #[async_trait] + impl SessionBuilder for CustomSessionBuilder { + async fn session_state( + &self, + mut state: SessionState, + ) -> Result { + state.retrieve_distributed_option_extension::()?; + add_user_codec(state.config_mut(), CustomConfigExtensionRequiredExecCodec); + Ok(state) + } + } + + let (mut ctx, _guard) = start_localhost_context(3, CustomSessionBuilder).await; + add_user_codec(&mut ctx, CustomConfigExtensionRequiredExecCodec); + ctx.add_distributed_option_extension(CustomExtension { + foo: "foo".to_string(), + bar: 1, + baz: true, + })?; + + let mut plan: Arc = Arc::new(CustomConfigExtensionRequiredExec::new()); + + for size in [1, 2, 3] { + plan = Arc::new(ArrowFlightReadExec::new_pending( + plan, + Partitioning::RoundRobinBatch(size), + )); + } + + let plan = DistributedPhysicalOptimizerRule::default().distribute_plan(plan)?; + let stream = execute_stream(Arc::new(plan), ctx.task_ctx())?; + // It should not fail. + stream.try_collect::>().await?; + + Ok(()) + } + + extensions_options! { + pub struct CustomExtension { + pub foo: String, default = "".to_string() + pub bar: usize, default = 0 + pub baz: bool, default = false + } + } + + impl ConfigExtension for CustomExtension { + const PREFIX: &'static str = "custom"; + } + + #[derive(Debug)] + pub struct CustomConfigExtensionRequiredExec { + plan_properties: PlanProperties, + } + + impl CustomConfigExtensionRequiredExec { + fn new() -> Self { + let schema = Schema::new(vec![Field::new("numbers", DataType::Int64, false)]); + Self { + plan_properties: PlanProperties::new( + EquivalenceProperties::new(Arc::new(schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ), + } + } + } + + impl DisplayAs for CustomConfigExtensionRequiredExec { + fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CustomConfigExtensionRequiredExec") + } + } + + impl ExecutionPlan for CustomConfigExtensionRequiredExec { + fn name(&self) -> &str { + "CustomConfigExtensionRequiredExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + _: usize, + ctx: Arc, + ) -> datafusion::common::Result { + if ctx + .session_config() + .options() + .extensions + .get::() + .is_none() + { + return internal_err!("CustomExtension not found in context"); + } + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::empty(), + ))) + } + } + + #[derive(Debug)] + struct CustomConfigExtensionRequiredExecCodec; + + #[derive(Clone, PartialEq, ::prost::Message)] + struct CustomConfigExtensionRequiredExecProto {} + + impl PhysicalExtensionCodec for CustomConfigExtensionRequiredExecCodec { + fn try_decode( + &self, + _buf: &[u8], + _: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> datafusion::common::Result> { + Ok(Arc::new(CustomConfigExtensionRequiredExec::new())) + } + + fn try_encode( + &self, + _node: Arc, + buf: &mut Vec, + ) -> datafusion::common::Result<()> { + CustomConfigExtensionRequiredExecProto::default() + .encode(buf) + .unwrap(); + Ok(()) + } + } +}