diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 693adc6da03a..e2ee1be7d732 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -517,14 +517,17 @@ pub fn parse_protobuf_file_scan_config(
     // Remove partition columns from the schema after recreating table_partition_cols
     // because the partition columns are not in the file. They are present to allow
    // the partition column types to be reconstructed after serde.
-    let file_schema = Arc::new(Schema::new(
-        schema
-            .fields()
-            .iter()
-            .filter(|field| !table_partition_cols.contains(field))
-            .cloned()
-            .collect::<Vec<_>>(),
-    ));
+    let file_schema = Arc::new(
+        Schema::new(
+            schema
+                .fields()
+                .iter()
+                .filter(|field| !table_partition_cols.contains(field))
+                .cloned()
+                .collect::<Vec<_>>(),
+        )
+        .with_metadata(schema.metadata.clone()),
+    );
 
     let mut output_ordering = vec![];
     for node_collection in &proto.output_ordering {
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index 28c3f84c5c7e..19a76de3e5b0 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -523,7 +523,11 @@ pub fn serialize_file_scan_config(
         .cloned()
         .collect::<Vec<_>>();
     fields.extend(conf.table_partition_cols.iter().cloned());
-    let schema = Arc::new(arrow::datatypes::Schema::new(fields.clone()));
+
+    let schema = Arc::new(
+        arrow::datatypes::Schema::new(fields.clone())
+            .with_metadata(conf.file_schema.metadata.clone()),
+    );
 
     Ok(protobuf::FileScanExecConf {
         file_groups,
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 825fc8e7bf64..b93d0d3c4e7c 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use std::any::Any;
+use std::collections::HashMap;
 use std::fmt::{Display, Formatter};
 use std::sync::Arc;
 
@@ -42,9 +43,11 @@ use datafusion::arrow::compute::kernels::sort::SortOptions;
 use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
 use datafusion::datasource::empty::EmptyTable;
 use datafusion::datasource::file_format::csv::CsvSink;
-use datafusion::datasource::file_format::json::JsonSink;
+use datafusion::datasource::file_format::json::{JsonFormat, JsonSink};
 use datafusion::datasource::file_format::parquet::ParquetSink;
-use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile};
+use datafusion::datasource::listing::{
+    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, PartitionedFile,
+};
 use datafusion::datasource::object_store::ObjectStoreUrl;
 use datafusion::datasource::physical_plan::{
     wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileGroup,
@@ -2221,3 +2224,41 @@ async fn roundtrip_memory_source() -> Result<()> {
     .await?;
     roundtrip_test(plan)
 }
+
+#[tokio::test]
+async fn roundtrip_listing_table_with_schema_metadata() -> Result<()> {
+    let ctx = SessionContext::new();
+    let file_format = JsonFormat::default();
+    let table_partition_cols = vec![("part".to_owned(), DataType::Int64)];
+    let data = "../core/tests/data/partitioned_table_json";
+    let listing_table_url = ListingTableUrl::parse(data)?;
+    let listing_options = ListingOptions::new(Arc::new(file_format))
+        .with_table_partition_cols(table_partition_cols);
+
+    let config = ListingTableConfig::new(listing_table_url)
+        .with_listing_options(listing_options)
+        .infer_schema(&ctx.state())
+        .await?;
+
+    // Decorate metadata onto the inferred ListingTable schema
+    let schema_with_meta = config
+        .file_schema
+        .clone()
+        .map(|s| {
+            let mut meta: HashMap<String, String> = HashMap::new();
+            meta.insert("foo.bar".to_string(), "baz".to_string());
+            s.as_ref().clone().with_metadata(meta)
+        })
+        .expect("Must decorate metadata");
+
+    let config = config.with_schema(Arc::new(schema_with_meta));
+    ctx.register_table("hive_style", Arc::new(ListingTable::try_new(config)?))?;
+
+    let plan = ctx
+        .sql("select * from hive_style limit 1")
+        .await?
+        .create_physical_plan()
+        .await?;
+
+    roundtrip_test(plan)
+}