5 changes: 3 additions & 2 deletions datafusion/physical-expr-adapter/src/lib.rs
@@ -29,6 +29,7 @@
pub mod schema_rewriter;

pub use schema_rewriter::{
-    DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter,
-    PhysicalExprAdapterFactory, replace_columns_with_literals,
+    BatchAdapter, BatchAdapterFactory, DefaultPhysicalExprAdapter,
+    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
+    replace_columns_with_literals,
};
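
With these re-exports, downstream crates can pull the new batch-adaptation types straight from the crate root. A minimal sketch (import only; the surrounding code is up to the caller):

```rust
use datafusion_physical_expr_adapter::{BatchAdapter, BatchAdapterFactory};
```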
348 changes: 348 additions & 0 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
@@ -24,6 +24,7 @@ use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;

use arrow::array::RecordBatch;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion_common::{
@@ -32,12 +33,15 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode},
};
use datafusion_functions::core::getfield::GetFieldFunc;
use datafusion_physical_expr::PhysicalExprSimplifier;
use datafusion_physical_expr::expressions::CastColumnExpr;
use datafusion_physical_expr::projection::{ProjectionExprs, Projector};
use datafusion_physical_expr::{
ScalarFunctionExpr,
expressions::{self, Column},
};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use itertools::Itertools;

/// Replace column references in the given physical expression with literal values.
///
@@ -473,6 +477,141 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> {
}
}

/// Factory for creating [`BatchAdapter`] instances to adapt record batches
/// to a target schema.
///
/// This binds a target schema and allows creating adapters for different source schemas.
/// It handles:
/// - **Column reordering**: Columns are reordered to match the target schema
/// - **Type casting**: Automatic type conversion (e.g., Int32 to Int64)
/// - **Missing columns**: Nullable columns missing from source are filled with nulls
/// - **Struct field adaptation**: Nested struct fields are recursively adapted
///
/// ## Examples
///
/// ```rust
/// use arrow::array::{Int32Array, RecordBatch, StringArray};
/// use arrow::datatypes::{DataType, Field, Schema};
/// use datafusion_physical_expr_adapter::BatchAdapterFactory;
/// use std::sync::Arc;
///
/// // Target schema has different column order and types
/// let target_schema = Arc::new(Schema::new(vec![
/// Field::new("name", DataType::Utf8, true),
/// Field::new("id", DataType::Int64, false), // Int64 in target
/// Field::new("score", DataType::Float64, true), // Missing from source
/// ]));
///
/// // Source schema has different column order and Int32 for id
/// let source_schema = Arc::new(Schema::new(vec![
/// Field::new("id", DataType::Int32, false), // Int32 in source
/// Field::new("name", DataType::Utf8, true),
/// // Note: 'score' column is missing from source
/// ]));
///
/// // Create factory with target schema
/// let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
///
/// // Create adapter for this specific source schema
/// let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();
///
/// // Create a source batch
/// let source_batch = RecordBatch::try_new(
/// source_schema,
/// vec![
/// Arc::new(Int32Array::from(vec![1, 2, 3])),
/// Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol"])),
/// ],
/// ).unwrap();
///
/// // Adapt the batch to match target schema
/// let adapted = adapter.adapt_batch(&source_batch).unwrap();
///
/// assert_eq!(adapted.num_columns(), 3);
/// assert_eq!(adapted.column(0).data_type(), &DataType::Utf8); // name
/// assert_eq!(adapted.column(1).data_type(), &DataType::Int64); // id (cast from Int32)
/// assert_eq!(adapted.column(2).data_type(), &DataType::Float64); // score (filled with nulls)
/// ```
#[derive(Debug)]
pub struct BatchAdapterFactory {
target_schema: SchemaRef,
expr_adapter_factory: Arc<dyn PhysicalExprAdapterFactory>,
}

impl BatchAdapterFactory {
/// Create a new [`BatchAdapterFactory`] with the given target schema.
pub fn new(target_schema: SchemaRef) -> Self {
let expr_adapter_factory = Arc::new(DefaultPhysicalExprAdapterFactory);
Self {
target_schema,
expr_adapter_factory,
}
}

/// Set a custom [`PhysicalExprAdapterFactory`] to use when adapting expressions.
///
/// Use this to customize behavior when adapting batches, e.g. to fill in missing values
/// with defaults instead of nulls.
///
/// See [`PhysicalExprAdapter`] for more details.
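///
/// A minimal sketch of supplying a custom factory; `MyDefaultsFactory` is a
/// hypothetical type implementing [`PhysicalExprAdapterFactory`] that fills
/// missing columns with default values instead of nulls:
///
/// ```ignore
/// let factory = BatchAdapterFactory::new(target_schema)
///     .with_adapter_factory(Arc::new(MyDefaultsFactory::default()));
/// ```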
pub fn with_adapter_factory(
self,
factory: Arc<dyn PhysicalExprAdapterFactory>,
) -> Self {
Self {
expr_adapter_factory: factory,
..self
}
}

/// Create a new [`BatchAdapter`] for the given source schema.
///
/// Batches fed into this [`BatchAdapter`] *must* conform to the source schema;
/// no validation is performed at runtime, to minimize overhead.
pub fn make_adapter(&self, source_schema: SchemaRef) -> Result<BatchAdapter> {
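// Build an expression adapter that rewrites references to target-schema
// columns into expressions over the source schema (casts, null fills for
// missing columns, nested struct adaptation).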
let expr_adapter = self
.expr_adapter_factory
.create(Arc::clone(&self.target_schema), Arc::clone(&source_schema));

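// Simplify the rewritten expressions against the target schema
// (e.g. fold away redundant casts).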
let simplifier = PhysicalExprSimplifier::new(&self.target_schema);

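// Start from an identity projection that selects every column of the
// target schema in order.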
let projection = ProjectionExprs::from_indices(
&(0..self.target_schema.fields().len()).collect_vec(),
&self.target_schema,
);

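// Run each projection expression through the adapter and simplifier,
// then bind the result to the source schema as a reusable projector.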
let adapted = projection
.try_map_exprs(|e| simplifier.simplify(expr_adapter.rewrite(e)?))?;
let projector = adapted.make_projector(&source_schema)?;

Ok(BatchAdapter { projector })
}
}

/// Adapter for transforming record batches to match a target schema.
///
/// Create instances via [`BatchAdapterFactory`].
///
/// ## Performance
///
/// The adapter pre-computes the projection expressions during creation,
/// so the [`adapt_batch`](BatchAdapter::adapt_batch) call is efficient and suitable
/// for use in hot paths like streaming file scans.
#[derive(Debug)]
pub struct BatchAdapter {
projector: Projector,
}

impl BatchAdapter {
/// Adapt the given record batch to match the target schema.
///
/// The input batch *must* conform to the source schema used when
/// creating this adapter.
pub fn adapt_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
self.projector.project_batch(batch)
}
}

#[cfg(test)]
mod tests {
use super::*;
@@ -1046,4 +1185,213 @@ mod tests {
// with ScalarUDF, which is complex to set up in a unit test. The integration tests in
// datafusion/core/tests/parquet/schema_adapter.rs provide better coverage for this functionality.
}

// ============================================================================
// BatchAdapterFactory and BatchAdapter tests
// ============================================================================

#[test]
fn test_batch_adapter_factory_basic() {
// Target schema
let target_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Utf8, true),
]));

// Source schema with different column order and type
let source_schema = Arc::new(Schema::new(vec![
Field::new("b", DataType::Utf8, true),
Field::new("a", DataType::Int32, false), // Int32 -> Int64
]));

let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();

// Create source batch
let source_batch = RecordBatch::try_new(
Arc::clone(&source_schema),
vec![
Arc::new(StringArray::from(vec![Some("hello"), None, Some("world")])),
Arc::new(Int32Array::from(vec![1, 2, 3])),
],
)
.unwrap();

let adapted = adapter.adapt_batch(&source_batch).unwrap();

// Verify schema matches target
assert_eq!(adapted.num_columns(), 2);
assert_eq!(adapted.schema().field(0).name(), "a");
assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int64);
assert_eq!(adapted.schema().field(1).name(), "b");
assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);

// Verify data
let col_a = adapted
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(col_a.iter().collect_vec(), vec![Some(1), Some(2), Some(3)]);

let col_b = adapted
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(
col_b.iter().collect_vec(),
vec![Some("hello"), None, Some("world")]
);
}

#[test]
fn test_batch_adapter_factory_missing_column() {
// Target schema with a column missing from source
let target_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true), // exists in source
Field::new("c", DataType::Float64, true), // missing from source
]));

let source_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true),
]));

let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
let adapter = factory.make_adapter(Arc::clone(&source_schema)).unwrap();

let source_batch = RecordBatch::try_new(
Arc::clone(&source_schema),
vec![
Arc::new(Int32Array::from(vec![1, 2])),
Arc::new(StringArray::from(vec!["x", "y"])),
],
)
.unwrap();

let adapted = adapter.adapt_batch(&source_batch).unwrap();

assert_eq!(adapted.num_columns(), 3);

// Missing column should be filled with nulls
let col_c = adapted.column(2);
assert_eq!(col_c.data_type(), &DataType::Float64);
assert_eq!(col_c.null_count(), 2); // All nulls
}

#[test]
fn test_batch_adapter_factory_with_struct() {
// Target has struct with Int64 id
let target_struct_fields: Fields = vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, true),
]
.into();
let target_schema = Arc::new(Schema::new(vec![Field::new(
"data",
DataType::Struct(target_struct_fields),
false,
)]));

// Source has struct with Int32 id
let source_struct_fields: Fields = vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]
.into();
let source_schema = Arc::new(Schema::new(vec![Field::new(
"data",
DataType::Struct(source_struct_fields.clone()),
false,
)]));

let struct_array = StructArray::new(
source_struct_fields,
vec![
Arc::new(Int32Array::from(vec![10, 20])) as _,
Arc::new(StringArray::from(vec!["a", "b"])) as _,
],
None,
);

let source_batch = RecordBatch::try_new(
Arc::clone(&source_schema),
vec![Arc::new(struct_array)],
)
.unwrap();

let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));
let adapter = factory.make_adapter(source_schema).unwrap();
let adapted = adapter.adapt_batch(&source_batch).unwrap();

let result_struct = adapted
.column(0)
.as_any()
.downcast_ref::<StructArray>()
.unwrap();

// Verify id was cast to Int64
let id_col = result_struct.column_by_name("id").unwrap();
assert_eq!(id_col.data_type(), &DataType::Int64);
let id_values = id_col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(id_values.iter().collect_vec(), vec![Some(10), Some(20)]);
}

#[test]
fn test_batch_adapter_factory_identity() {
// When source and target schemas are identical, batches should pass through unchanged
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Utf8, true),
]));

let factory = BatchAdapterFactory::new(Arc::clone(&schema));
let adapter = factory.make_adapter(Arc::clone(&schema)).unwrap();

let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["a", "b", "c"])),
],
)
.unwrap();

let adapted = adapter.adapt_batch(&batch).unwrap();

assert_eq!(adapted.num_columns(), 2);
assert_eq!(adapted.schema().field(0).data_type(), &DataType::Int32);
assert_eq!(adapted.schema().field(1).data_type(), &DataType::Utf8);
}

#[test]
fn test_batch_adapter_factory_reuse() {
// Factory can create multiple adapters for different source schemas
let target_schema = Arc::new(Schema::new(vec![
Field::new("x", DataType::Int64, false),
Field::new("y", DataType::Utf8, true),
]));

let factory = BatchAdapterFactory::new(Arc::clone(&target_schema));

// First source schema
let source1 = Arc::new(Schema::new(vec![
Field::new("x", DataType::Int32, false),
Field::new("y", DataType::Utf8, true),
]));
let adapter1 = factory.make_adapter(source1).unwrap();

// Second source schema (different order)
let source2 = Arc::new(Schema::new(vec![
Field::new("y", DataType::Utf8, true),
Field::new("x", DataType::Int64, false),
]));
let adapter2 = factory.make_adapter(source2).unwrap();

// Both adapters should be constructed successfully (smoke-checked via their Debug output)
assert!(format!("{:?}", adapter1).contains("BatchAdapter"));
assert!(format!("{:?}", adapter2).contains("BatchAdapter"));
}
}