
Commit 803199a

feat(datafusion): adapt IcebergProjectExec to use one partition column containing all the partition values
1 parent d930df9 commit 803199a

File tree

1 file changed (+83 -52 lines)

  • crates/integrations/datafusion/src/physical_plan/project.rs


crates/integrations/datafusion/src/physical_plan/project.rs

Lines changed: 83 additions & 52 deletions
@@ -19,8 +19,10 @@ use std::any::Any;
 use std::fmt::{Debug, Formatter};
 use std::sync::Arc;
 
-use datafusion::arrow::array::{ArrayRef, RecordBatch};
-use datafusion::arrow::datatypes::{Field, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef};
+use datafusion::arrow::array::{ArrayRef, RecordBatch, StructArray};
+use datafusion::arrow::datatypes::{
+    DataType, Field, Schema as ArrowSchema, SchemaRef as ArrowSchemaRef,
+};
 use datafusion::common::Result as DFResult;
 use datafusion::error::DataFusionError;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
@@ -36,16 +38,16 @@ use iceberg::spec::{PartitionSpec, Schema};
 
 use crate::to_datafusion_error;
 
-/// Prefix for partition column names to avoid collisions with regular columns
-const PARTITION_COLUMN_PREFIX: &str = "__partition_";
+/// Column name for the combined partition values struct
+const PARTITION_VALUES_COLUMN: &str = "_iceberg_partition_values";
 
 /// An execution plan node that calculates partition values for Iceberg tables.
 ///
-/// This execution plan takes input data from a child execution plan and adds partition columns
-/// based on the table's partition specification. The partition values are computed by applying
-/// the appropriate transforms to the source columns.
+/// This execution plan takes input data from a child execution plan and adds a single struct column
+/// containing all partition values based on the table's partition specification. The partition values
+/// are computed by applying the appropriate transforms to the source columns.
 ///
-/// The output schema includes all original columns plus additional partition columns.
+/// The output schema includes all original columns plus a single `_iceberg_partition_values` struct column.
 #[derive(Debug, Clone)]
 pub(crate) struct IcebergProjectExec {
     input: Arc<dyn ExecutionPlan>,
@@ -56,10 +58,12 @@ pub(crate) struct IcebergProjectExec {
 }
 
 /// IcebergProjectExec is responsible for calculating partition values for Iceberg tables.
-/// It takes input data from a child execution plan and adds partition columns based on the table's
-/// partition specification. The partition values are computed by applying the appropriate transforms
-/// to the source columns. The output schema includes all original columns plus additional partition
-/// columns.
+/// It takes input data from a child execution plan and adds a single struct column containing
+/// all partition values based on the table's partition specification. The partition values are
+/// computed by applying the appropriate transforms to the source columns. The output schema
+/// includes all original columns plus a single `_iceberg_partition_values` struct column.
+/// This approach simplifies downstream repartitioning operations by providing a single column
+/// that can be directly used for sorting and repartitioning.
 impl IcebergProjectExec {
     pub fn new(
         input: Arc<dyn ExecutionPlan>,
@@ -92,7 +96,7 @@ impl IcebergProjectExec {
         )
     }
 
-    /// Create the output schema by adding partition columns to the input schema
+    /// Create the output schema by adding a single partition values struct column to the input schema
     fn create_output_schema(
         input_schema: &ArrowSchema,
         partition_spec: &PartitionSpec,
@@ -104,38 +108,51 @@ impl IcebergProjectExec {
 
         let mut fields: Vec<Arc<Field>> = input_schema.fields().to_vec();
 
-        let partition_struct = partition_spec
+        let partition_struct_type = partition_spec
             .partition_type(table_schema)
             .map_err(to_datafusion_error)?;
 
-        for (idx, pf) in partition_spec.fields().iter().enumerate() {
-            let struct_field = partition_struct.fields().get(idx).ok_or_else(|| {
-                DataFusionError::Internal(
-                    "Partition field index out of bounds when creating output schema".to_string(),
-                )
-            })?;
-            let arrow_type = iceberg::arrow::type_to_arrow_type(&struct_field.field_type)
+        // Convert the Iceberg struct type to Arrow struct type
+        let arrow_struct_type =
+            iceberg::arrow::type_to_arrow_type(&iceberg::spec::Type::Struct(partition_struct_type))
                 .map_err(to_datafusion_error)?;
-            let partition_column_name = Self::create_partition_column_name(&pf.name);
-            let nullable = !struct_field.required;
-            fields.push(Arc::new(Field::new(
-                &partition_column_name,
-                arrow_type,
-                nullable,
-            )));
-        }
+
+        // Add a single struct column containing all partition values
+        fields.push(Arc::new(Field::new(
+            PARTITION_VALUES_COLUMN,
+            arrow_struct_type,
+            false, // Partition values are generally not null
+        )));
+
         Ok(Arc::new(ArrowSchema::new(fields)))
     }
 
-    /// Calculate partition values for a record batch
-    fn calculate_partition_values(&self, batch: &RecordBatch) -> DFResult<Vec<ArrayRef>> {
+    /// Calculate partition values for a record batch and return as a single struct array
+    fn calculate_partition_values(&self, batch: &RecordBatch) -> DFResult<Option<ArrayRef>> {
         if self.partition_spec.is_unpartitioned() {
-            return Ok(vec![]);
+            return Ok(None);
         }
 
         let batch_schema = batch.schema();
         let mut partition_values = Vec::with_capacity(self.partition_spec.fields().len());
 
+        // Get the expected struct fields from our output schema
+        let partition_column_field = self
+            .output_schema
+            .field_with_name(PARTITION_VALUES_COLUMN)
+            .map_err(|e| {
+                DataFusionError::Internal(format!("Partition column not found in schema: {}", e))
+            })?;
+
+        let expected_struct_fields = match partition_column_field.data_type() {
+            DataType::Struct(fields) => fields.clone(),
+            _ => {
+                return Err(DataFusionError::Internal(
+                    "Partition column is not a struct type".to_string(),
+                ));
+            }
+        };
+
         for pf in self.partition_spec.fields() {
             // Find the source field in the table schema
             let source_field = self.table_schema.field_by_id(pf.source_id).ok_or_else(|| {
@@ -158,7 +175,16 @@ impl IcebergProjectExec {
 
             partition_values.push(partition_value);
         }
-        Ok(partition_values)
+
+        // Create struct array using the expected fields from the schema
+        let struct_array = StructArray::try_new(
+            expected_struct_fields,
+            partition_values,
+            None, // No null buffer for the struct array itself
+        )
+        .map_err(|e| DataFusionError::ArrowError(e, None))?;
+
+        Ok(Some(Arc::new(struct_array)))
     }
 
     /// Extract a column by an index path
@@ -287,21 +313,18 @@ impl IcebergProjectExec {
         Ok(indices)
     }
 
-    /// Apply a naming convention for partition columns using spec alias, prefixed to avoid collisions
-    fn create_partition_column_name(partition_field_alias: &str) -> String {
-        format!("{}{}", PARTITION_COLUMN_PREFIX, partition_field_alias)
-    }
-
-    /// Process a single batch by adding partition columns
+    /// Process a single batch by adding a partition values struct column
     fn process_batch(&self, batch: RecordBatch) -> DFResult<RecordBatch> {
         if self.partition_spec.is_unpartitioned() {
            return Ok(batch);
         }
 
-        let partition_arrays = self.calculate_partition_values(&batch)?;
+        let partition_array = self.calculate_partition_values(&batch)?;
 
         let mut all_columns = batch.columns().to_vec();
-        all_columns.extend(partition_arrays);
+        if let Some(partition_array) = partition_array {
+            all_columns.push(partition_array);
+        }
 
         RecordBatch::try_new(Arc::clone(&self.output_schema), all_columns)
             .map_err(|e| DataFusionError::ArrowError(e, None))
@@ -501,11 +524,11 @@ mod tests {
         IcebergProjectExec::create_output_schema(&arrow_schema, &partition_spec, &table_schema)
             .unwrap();
 
-        // Should have 3 fields: original 2 + 1 partition field
+        // Should have 3 fields: original 2 + 1 partition values struct
         assert_eq!(output_schema.fields().len(), 3);
         assert_eq!(output_schema.field(0).name(), "id");
         assert_eq!(output_schema.field(1).name(), "name");
-        assert_eq!(output_schema.field(2).name(), "__partition_id_partition");
+        assert_eq!(output_schema.field(2).name(), "_iceberg_partition_values");
     }
 
     #[test]
@@ -548,11 +571,11 @@ mod tests {
         IcebergProjectExec::create_output_schema(&arrow_schema, &partition_spec, &table_schema)
             .unwrap();
 
-        // Should have 3 fields: id, address, and partition field for city
+        // Should have 3 fields: id, address, and partition values struct
         assert_eq!(output_schema.fields().len(), 3);
         assert_eq!(output_schema.field(0).name(), "id");
         assert_eq!(output_schema.field(1).name(), "address");
-        assert_eq!(output_schema.field(2).name(), "__partition_city_partition");
+        assert_eq!(output_schema.field(2).name(), "_iceberg_partition_values");
     }
 
     #[test]
@@ -636,26 +659,34 @@ mod tests {
         let result_batch = project_exec.process_batch(batch).unwrap();
 
         // Verify the result
-        assert_eq!(result_batch.num_columns(), 3); // id, address, partition
+        assert_eq!(result_batch.num_columns(), 3); // id, address, partition_values
         assert_eq!(result_batch.num_rows(), 3);
 
         // Verify column names
         assert_eq!(result_batch.schema().field(0).name(), "id");
         assert_eq!(result_batch.schema().field(1).name(), "address");
         assert_eq!(
             result_batch.schema().field(2).name(),
-            "__partition_city_partition"
+            "_iceberg_partition_values"
         );
 
-        // Verify that the partition column contains the city values extracted from the struct
+        // Verify that the partition values struct contains the city values extracted from the struct
         let partition_column = result_batch.column(2);
-        let partition_array = partition_column
+        let partition_struct_array = partition_column
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+
+        // Get the city_partition field from the struct
+        let city_partition_array = partition_struct_array
+            .column_by_name("city_partition")
+            .unwrap()
             .as_any()
             .downcast_ref::<datafusion::arrow::array::StringArray>()
             .unwrap();
 
-        assert_eq!(partition_array.value(0), "New York");
-        assert_eq!(partition_array.value(1), "Los Angeles");
-        assert_eq!(partition_array.value(2), "Chicago");
+        assert_eq!(city_partition_array.value(0), "New York");
+        assert_eq!(city_partition_array.value(1), "Los Angeles");
+        assert_eq!(city_partition_array.value(2), "Chicago");
     }
 }
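The diff's doc comments say the single `_iceberg_partition_values` struct column exists to simplify downstream sorting and repartitioning. A minimal sketch of the underlying mechanic using plain arrow-rs APIs (re-exported through `datafusion::arrow`); the partition field names (`city_partition`, `region_partition`) and values are hypothetical, not from this commit:

```rust
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, StringArray, StructArray};
use datafusion::arrow::datatypes::{DataType, Field, Fields};

fn main() {
    // Hypothetical per-field partition values, as the partition transforms would produce.
    let cities: ArrayRef = Arc::new(StringArray::from(vec!["New York", "Chicago"]));
    let regions: ArrayRef = Arc::new(StringArray::from(vec!["NY", "IL"]));

    // Combine them into one struct array, mirroring what calculate_partition_values
    // now returns instead of a Vec<ArrayRef> of separate partition columns.
    let fields = Fields::from(vec![
        Field::new("city_partition", DataType::Utf8, true),
        Field::new("region_partition", DataType::Utf8, true),
    ]);
    let partition_values = StructArray::try_new(fields, vec![cities, regions], None).unwrap();

    // A consumer can still reach an individual partition value by name.
    let city = partition_values.column_by_name("city_partition").unwrap();
    assert_eq!(city.len(), 2);
}
```

And a sketch of the downstream repartitioning the doc comment alludes to: hash-partitioning on the combined column. This is not code from the commit; it assumes DataFusion's hash repartitioning accepts a struct-typed column expression, and the partition count of 8 is arbitrary:

```rust
use std::sync::Arc;

use datafusion::common::Result as DFResult;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};

/// Hash-repartition on the single `_iceberg_partition_values` column appended
/// by IcebergProjectExec, so rows with equal partition values are co-located.
fn repartition_by_partition_values(
    input: Arc<dyn ExecutionPlan>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
    let schema = input.schema();
    // Locate the combined partition values column in the output schema.
    let idx = schema.index_of("_iceberg_partition_values")?;
    let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("_iceberg_partition_values", idx));
    Ok(Arc::new(RepartitionExec::try_new(
        input,
        Partitioning::Hash(vec![expr], 8),
    )?))
}
```

With the previous design, the same operation would have needed one expression per `__partition_*` column; a single struct column keeps the repartitioning (and any sort) keyed on one expression regardless of how many partition fields the spec defines.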
