diff --git a/datafusion/physical-expr-adapter/src/lib.rs b/datafusion/physical-expr-adapter/src/lib.rs index d7c750e4a1a1c..5ae86f219b6f1 100644 --- a/datafusion/physical-expr-adapter/src/lib.rs +++ b/datafusion/physical-expr-adapter/src/lib.rs @@ -29,6 +29,7 @@ pub mod schema_rewriter; pub use schema_rewriter::{ - DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, - PhysicalExprAdapterFactory, replace_columns_with_literals, + BatchAdapter, BatchAdapterFactory, DefaultPhysicalExprAdapter, + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, + replace_columns_with_literals, }; diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 83727ac092044..b2bed36f0e740 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -24,6 +24,7 @@ use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; +use arrow::array::RecordBatch; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::{ @@ -32,12 +33,15 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::PhysicalExprSimplifier; use datafusion_physical_expr::expressions::CastColumnExpr; +use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; use datafusion_physical_expr::{ ScalarFunctionExpr, expressions::{self, Column}, }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::Itertools; /// Replace column references in the given physical expression with literal values. /// @@ -473,6 +477,141 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { } } +/// Factory for creating [`BatchAdapter`] instances to adapt record batches +/// to a target schema. +/// +/// This binds a target schema and allows creating adapters for different source schemas. +/// It handles: +/// - **Column reordering**: Columns are reordered to match the target schema +/// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64) +/// - **Missing columns**: Nullable columns missing from source are filled with nulls +/// - **Struct field adaptation**: Nested struct fields are recursively adapted +/// +/// ## Examples +/// +/// ```rust +/// use arrow::array::{Int32Array, Int64Array, StringArray, RecordBatch}; +/// use arrow::datatypes::{DataType, Field, Schema}; +/// use datafusion_physical_expr_adapter::BatchAdapterFactory; +/// use std::sync::Arc; +/// +/// // Target schema has different column order and types +/// let target_schema = Arc::new(Schema::new(vec![ +/// Field::new("name", DataType::Utf8, true), +/// Field::new("id", DataType::Int64, false), // Int64 in target +/// Field::new("score", DataType::Float64, true), // Missing from source +/// ])); +/// +/// // Source schema has different column order and Int32 for id +/// let source_schema = Arc::new(Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), // Int32 in source +/// Field::new("name", DataType::Utf8, true), +/// // Note: 'score' column is missing from source +/// ])); +/// +/// // Create factory with target schema +/// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); +/// +/// // Create adapter for this specific source schema +/// let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap(); +/// +/// // Create a source batch +/// let source_batch = RecordBatch::try_new( +/// source_schema, +/// vec![ +/// Arc::new(Int32Array::from(vec![1, 2, 3])), +/// Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])), +/// ], +/// ).unwrap(); +/// +/// // Adapt the batch to match target schema +/// let adapted = adapter.adapt_batch(&source_batch).unwrap(); +/// +/// assert_eq!(adapted.num_columns(), 3); +/// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8); // name +/// assert_eq!(adapted.column(1).data_type(), &DataType::Int64); // id (cast from Int32) +/// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls) +/// ``` +#[derive(Debug)] +pub struct BatchAdapterFactory { + target_schema: SchemaRef, + expr_adapter_factory: Arc, +} + +impl BatchAdapterFactory { + /// Create a new [`BatchAdapterFactory`] with the given target schema. + pub fn new(target_schema: SchemaRef) -> Self { + let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory); + Self { + target_schema, + expr_adapter_factory, + } + } + + /// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions. + /// + /// Use this to customize behavior when adapting batches, e.g. to fill in missing values + /// with defaults instead of nulls. + /// + /// See [`PhysicalExprAdapter`] for more details. + pub fn with_adapter_factory( + self, + factory: Arc, + ) -> Self { + Self { + expr_adapter_factory: factory, + ..self + } + } + + /// Create a new [`BatchAdapter`] for the given source schema. + /// + /// Batches fed into this [`BatchAdapter`] *must* conform to the source schema, + /// no validation is performed at runtime to minimize overheads. + pub fn make_adapter(&self, source_schema: SchemaRef) -> Result { + let expr_adapter = self + .expr_adapter_factory + .create(Arc::clone(&self.target_schema), Arc::clone(&source_schema)); + + let simplifier = PhysicalExprSimplifier::new(&self.target_schema); + + let projection = ProjectionExprs::from_indices( + &(0..self.target_schema.fields().len()).collect_vec(), + &self.target_schema, + ); + + let adapted = projection + .try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?; + let projector = adapted.make_projector(&source_schema)?; + + Ok(BatchAdapter { projector }) + } +} + +/// Adapter for transforming record batches to match a target schema. +/// +/// Create instances via [`BatchAdapterFactory`]. +/// +/// ## Performance +/// +/// The adapter pre-computes the projection expressions during creation, +/// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable +/// for use in hot paths like streaming file scans. +#[derive(Debug)] +pub struct BatchAdapter { + projector: Projector, +} + +impl BatchAdapter { + /// Adapt the given record batch to match the target schema. + /// + /// The input batch *must* conform to the source schema used when + /// creating this adapter. + pub fn adapt_batch(&self, batch: &RecordBatch) -> Result { + self.projector.project_batch(batch) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1046,4 +1185,213 @@ mod tests { // with ScalarUDF, which is complex to set up in a unit test. The integration tests in // datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality. } + + // ============================================================================ + // BatchAdapterFactory and BatchAdapter tests + // ============================================================================ + + #[test] + fn test_batch_adapter_factory_basic() { + // Target schema + let target_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, true), + ])); + + // Source schema with different column order and type + let source_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Utf8, true), + Field::new("a", DataType::Int32, false), // Int32 -> Int64 + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap(); + + // Create source batch + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![ + Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + // Verify schema matches target + assert_eq!(adapted.num_columns(), 2); + assert_eq!(adapted.schema().field(0).name(), "a"); + assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64); + assert_eq!(adapted.schema().field(1).name(), "b"); + assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8); + + // Verify data + let col_a = adapted + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]); + + let col_b = adapted + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + col_b.iter().collect_vec(), + vec![Some("hello"), None, Some("world")] + ); + } + + #[test] + fn test_batch_adapter_factory_missing_column() { + // Target schema with a column missing from source + let target_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), // exists in source + Field::new("c", DataType::Float64, true), // missing from source + ])); + + let source_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap(); + + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + assert_eq!(adapted.num_columns(), 3); + + // Missing column should be filled with nulls + let col_c = adapted.column(2); + assert_eq!(col_c.data_type(), &DataType::Float64); + assert_eq!(col_c.null_count(), 2); // All nulls + } + + #[test] + fn test_batch_adapter_factory_with_struct() { + // Target has struct with Int64 id + let target_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let target_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::Struct(target_struct_fields), + false, + )])); + + // Source has struct with Int32 id + let source_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let source_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::Struct(source_struct_fields.clone()), + false, + )])); + + let struct_array = StructArray::new( + source_struct_fields, + vec![ + Arc::new(Int32Array::from(vec![10, 20])) as _, + Arc::new(StringArray::from(vec!["a", "b"])) as _, + ], + None, + ); + + let source_batch = RecordBatch::try_new( + Arc::clone(&source_schema), + vec![Arc::new(struct_array)], + ) + .unwrap(); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + let adapter = factory.make_adapter(source_schema).unwrap(); + let adapted = adapter.adapt_batch(&source_batch).unwrap(); + + let result_struct = adapted + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify id was cast to Int64 + let id_col = result_struct.column_by_name("id").unwrap(); + assert_eq!(id_col.data_type(), &DataType::Int64); + let id_values = id_col.as_any().downcast_ref::().unwrap(); + assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]); + } + + #[test] + fn test_batch_adapter_factory_identity() { + // When source and target schemas are identical, should pass through efficiently + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&schema)); + let adapter = factory.make_adapter(Arc::clone(&schema)).unwrap(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + let adapted = adapter.adapt_batch(&batch).unwrap(); + + assert_eq!(adapted.num_columns(), 2); + assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32); + assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8); + } + + #[test] + fn test_batch_adapter_factory_reuse() { + // Factory can create multiple adapters for different source schemas + let target_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Utf8, true), + ])); + + let factory = BatchAdapterFactory::new(Arc::clone(&target_schema)); + + // First source schema + let source1 = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ])); + let adapter1 = factory.make_adapter(source1).unwrap(); + + // Second source schema (different order) + let source2 = Arc::new(Schema::new(vec![ + Field::new("y", DataType::Utf8, true), + Field::new("x", DataType::Int64, false), + ])); + let adapter2 = factory.make_adapter(source2).unwrap(); + + // Both should work correctly + assert!(format!("{:?}", adapter1).contains("BatchAdapter")); + assert!(format!("{:?}", adapter2).contains("BatchAdapter")); + } }