Commit edb4719

refactor: simplify RecordBatchProjector and optimize partition calculation
Signed-off-by: Florian Valeye <[email protected]>
1 parent: bdc6aae

File tree: 3 files changed, +72 −78 lines

  crates/iceberg/src/arrow/record_batch_projector.rs
  crates/iceberg/src/transform/mod.rs
  crates/integrations/datafusion/src/physical_plan/project.rs


crates/iceberg/src/arrow/record_batch_projector.rs

Lines changed: 34 additions & 45 deletions
@@ -22,6 +22,7 @@ use arrow_buffer::NullBuffer;
 use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
 use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 
+use crate::arrow::schema::schema_to_arrow_schema;
 use crate::error::Result;
 use crate::spec::Schema as IcebergSchema;
 use crate::{Error, ErrorKind};
@@ -79,22 +80,21 @@ impl RecordBatchProjector {
         })
     }
 
-    /// Create RecordBatchProjector using Iceberg schema for field mapping.
+    /// Create RecordBatchProjector using Iceberg schema.
     ///
-    /// This constructor is more flexible and works with any Arrow schema by using
-    /// the Iceberg schema to map field names to field IDs.
+    /// This constructor converts the Iceberg schema to Arrow schema with field ID metadata,
+    /// then uses the standard field ID lookup for projection.
     ///
     /// # Arguments
-    /// * `original_schema` - The original Arrow schema (doesn't need field ID metadata)
-    /// * `iceberg_schema` - The Iceberg schema for field ID mapping
+    /// * `iceberg_schema` - The Iceberg schema for field ID mapping
     /// * `target_field_ids` - The field IDs to project
-    pub fn from_iceberg_schema_mapping(
-        original_schema: SchemaRef,
+    pub fn from_iceberg_schema(
         iceberg_schema: Arc<IcebergSchema>,
         target_field_ids: &[i32],
     ) -> Result<Self> {
+        let arrow_schema_with_ids = Arc::new(schema_to_arrow_schema(&iceberg_schema)?);
+
         let field_id_fetch_func = |field: &Field| -> Result<Option<i64>> {
-            // First try to get field ID from metadata (Parquet case)
             if let Some(value) = field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
                 let field_id = value.parse::<i32>().map_err(|e| {
                     Error::new(
@@ -104,49 +104,16 @@ impl RecordBatchProjector {
                         .with_context("value", value)
                         .with_source(e)
                 })?;
-                return Ok(Some(field_id as i64));
-            }
-
-            // Fallback: use Iceberg schema's built-in field lookup
-            if let Some(iceberg_field) = iceberg_schema.field_by_name(field.name()) {
-                return Ok(Some(iceberg_field.id as i64));
-            }
-
-            // Additional fallback: for nested fields, we need to search recursively
-            fn find_field_id_in_struct(
-                struct_type: &crate::spec::StructType,
-                field_name: &str,
-            ) -> Option<i32> {
-                for field in struct_type.fields() {
-                    if field.name == field_name {
-                        return Some(field.id);
-                    }
-                    if let crate::spec::Type::Struct(nested_struct) = &*field.field_type {
-                        if let Some(nested_id) = find_field_id_in_struct(nested_struct, field_name)
-                        {
-                            return Some(nested_id);
-                        }
-                    }
-                }
-                None
-            }
-
-            // Search in nested structs
-            for iceberg_field in iceberg_schema.as_struct().fields() {
-                if let crate::spec::Type::Struct(struct_type) = &*iceberg_field.field_type {
-                    if let Some(nested_id) = find_field_id_in_struct(struct_type, field.name()) {
-                        return Ok(Some(nested_id as i64));
-                    }
-                }
+                Ok(Some(field_id as i64))
+            } else {
+                Ok(None)
             }
-
-            Ok(None)
         };
 
         let searchable_field_func = |_field: &Field| -> bool { true };
 
         Self::new(
-            original_schema,
+            arrow_schema_with_ids,
             target_field_ids,
             field_id_fetch_func,
             searchable_field_func,
@@ -242,6 +209,7 @@ mod test {
     use arrow_schema::{DataType, Field, Fields, Schema};
 
     use crate::arrow::record_batch_projector::RecordBatchProjector;
+    use crate::spec::{NestedField, PrimitiveType, Schema as IcebergSchema, Type};
     use crate::{Error, ErrorKind};
 
     #[test]
@@ -369,4 +337,25 @@ mod test {
             RecordBatchProjector::new(schema.clone(), &[3], field_id_fetch_func, |_| true);
         assert!(projector.is_ok());
     }
+
+    #[test]
+    fn test_from_iceberg_schema() {
+        let iceberg_schema = IcebergSchema::builder()
+            .with_schema_id(0)
+            .with_fields(vec![
+                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
+                NestedField::optional(3, "age", Type::Primitive(PrimitiveType::Int)).into(),
+            ])
+            .build()
+            .unwrap();
+
+        let projector =
+            RecordBatchProjector::from_iceberg_schema(Arc::new(iceberg_schema), &[1, 3]).unwrap();
+
+        assert_eq!(projector.field_indices.len(), 2);
+        assert_eq!(projector.projected_schema_ref().fields().len(), 2);
+        assert_eq!(projector.projected_schema_ref().field(0).name(), "id");
+        assert_eq!(projector.projected_schema_ref().field(1).name(), "age");
+    }
 }
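Note on why the deleted name-based and recursive fallbacks are safe to drop: schema_to_arrow_schema emits an Arrow schema in which every field already carries its Iceberg field ID in the field metadata under PARQUET_FIELD_ID_META_KEY, so the single metadata lookup in the simplified field_id_fetch_func is sufficient. A minimal standalone sketch of that invariant (illustration only, not the crate's implementation; the field name and ID here are made up):

use std::collections::HashMap;

use arrow_schema::{DataType, Field};
use parquet::arrow::PARQUET_FIELD_ID_META_KEY;

fn main() {
    // Hypothetical stand-in for what schema_to_arrow_schema produces: every
    // Arrow field carries its Iceberg field ID in the metadata map.
    let field = Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([(
        PARQUET_FIELD_ID_META_KEY.to_string(),
        "1".to_string(),
    )]));

    // The simplified field_id_fetch_func only has to read that one key;
    // no name-based or recursive struct lookup is needed any more.
    let field_id: Option<i64> = field
        .metadata()
        .get(PARQUET_FIELD_ID_META_KEY)
        .and_then(|v| v.parse::<i64>().ok());

    assert_eq!(field_id, Some(1));
}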

crates/iceberg/src/transform/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ mod truncate;
 mod void;
 
 /// TransformFunction is a trait that defines the interface for all transform functions.
-pub trait TransformFunction: Send + Sync {
+pub trait TransformFunction: Send + Sync + std::fmt::Debug {
     /// transform will take an input array and transform it into a new array.
     /// The implementation of this function will need to check and downcast the input to specific
     /// type.
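Why the std::fmt::Debug supertrait: the DataFusion integration below stores Vec<BoxedTransformFunction> inside a struct that derives Debug, and #[derive(Debug)] requires every field, including boxed trait objects, to implement Debug. A reduced sketch of the constraint, with a hypothetical Transform trait standing in for TransformFunction:

use std::fmt::Debug;

// Without the Debug supertrait here, #[derive(Debug)] on Holder below would
// fail to compile, because Box<dyn Transform> would have no Debug impl.
trait Transform: Send + Sync + Debug {
    fn apply(&self, x: i64) -> i64;
}

#[derive(Debug)]
struct Identity;

impl Transform for Identity {
    fn apply(&self, x: i64) -> i64 {
        x
    }
}

// Mirrors the shape of PartitionValueCalculator, which now stores
// Vec<BoxedTransformFunction> and still derives Debug.
#[derive(Debug)]
struct Holder {
    transforms: Vec<Box<dyn Transform>>,
}

fn main() {
    let holder = Holder {
        transforms: vec![Box::new(Identity)],
    };
    println!("{:?}", holder); // works because dyn Transform: Debug
}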

crates/integrations/datafusion/src/physical_plan/project.rs

Lines changed: 37 additions & 32 deletions
@@ -30,6 +30,7 @@ use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
 use iceberg::arrow::record_batch_projector::RecordBatchProjector;
 use iceberg::spec::{PartitionSpec, Schema};
 use iceberg::table::Table;
+use iceberg::transform::BoxedTransformFunction;
 
 use crate::to_datafusion_error;
 
@@ -126,7 +127,7 @@ impl PhysicalExpr for PartitionExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
-        let mut calculator = self
+        let calculator = self
             .calculator
             .lock()
             .map_err(|e| DataFusionError::Internal(format!("Failed to lock calculator: {}", e)))?;
@@ -183,12 +184,12 @@ impl std::hash::Hash for PartitionExpr {
 }
 
 /// Calculator for partition values in Iceberg tables
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 struct PartitionValueCalculator {
     partition_spec: PartitionSpec,
-    table_schema: Schema,
     partition_type: DataType,
-    projector: Option<RecordBatchProjector>,
+    projector: RecordBatchProjector,
+    transform_functions: Vec<BoxedTransformFunction>,
 }
 
 impl PartitionValueCalculator {
@@ -203,35 +204,37 @@ impl PartitionValueCalculator {
             ));
         }
 
+        let transform_functions: Result<Vec<BoxedTransformFunction>, _> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| iceberg::transform::create_transform_function(&pf.transform))
+            .collect();
+
+        let transform_functions = transform_functions.map_err(to_datafusion_error)?;
+
+        let source_field_ids: Vec<i32> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| pf.source_id)
+            .collect();
+
+        let projector = RecordBatchProjector::from_iceberg_schema(
+            Arc::new(table_schema.clone()),
+            &source_field_ids,
+        )
+        .map_err(to_datafusion_error)?;
+
         Ok(Self {
             partition_spec,
-            table_schema,
             partition_type,
-            projector: None,
+            projector,
+            transform_functions,
         })
     }
 
-    fn calculate(&mut self, batch: &RecordBatch) -> DFResult<ArrayRef> {
-        if self.projector.is_none() {
-            let source_field_ids: Vec<i32> = self
-                .partition_spec
-                .fields()
-                .iter()
-                .map(|pf| pf.source_id)
-                .collect();
-
-            let projector = RecordBatchProjector::from_iceberg_schema_mapping(
-                batch.schema(),
-                Arc::new(self.table_schema.clone()),
-                &source_field_ids,
-            )
-            .map_err(to_datafusion_error)?;
-
-            self.projector = Some(projector);
-        }
-
-        let projector = self.projector.as_ref().unwrap();
-        let source_columns = projector
+    fn calculate(&self, batch: &RecordBatch) -> DFResult<ArrayRef> {
+        let source_columns = self
+            .projector
             .project_column(batch.columns())
             .map_err(to_datafusion_error)?;
 
@@ -246,10 +249,7 @@ impl PartitionValueCalculator {
 
         let mut partition_values = Vec::with_capacity(self.partition_spec.fields().len());
 
-        for (source_column, pf) in source_columns.iter().zip(self.partition_spec.fields()) {
-            let transform_fn = iceberg::transform::create_transform_function(&pf.transform)
-                .map_err(to_datafusion_error)?;
-
+        for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) {
             let partition_value = transform_fn
                 .transform(source_column.clone())
                 .map_err(to_datafusion_error)?;
@@ -302,6 +302,11 @@ mod tests {
             .build()
             .unwrap();
 
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
         let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
         let calculator = PartitionValueCalculator::new(
             partition_spec.clone(),
@@ -476,7 +481,7 @@ mod tests {
             .unwrap();
 
         let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let mut calculator =
+        let calculator =
             PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap();
         let array = calculator.calculate(&batch).unwrap();
 
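The net effect of the project.rs changes: all fallible, allocation-heavy setup (transform functions and projector) moves into the constructor, so calculate can take &self and the per-batch path is free of lazy-initialization branches. A schematic sketch of the pattern with toy stand-in types, not the real iceberg/DataFusion API:

// Toy stand-ins for the real types; this only illustrates the refactor shape.
#[derive(Debug)]
struct Projector;

impl Projector {
    fn project(&self, batch: &[i64]) -> Vec<i64> {
        batch.to_vec()
    }
}

#[derive(Debug)]
struct Calculator {
    // Built once in new(); previously an Option<_> filled in lazily on the
    // first calculate() call, which forced &mut self and a runtime branch.
    projector: Projector,
}

impl Calculator {
    fn new() -> Self {
        // All construction cost and fallibility is paid up front, once.
        Self { projector: Projector }
    }

    // &self instead of &mut self: the Mutex guard in evaluate() no longer
    // needs mutable access, and the hot loop does no initialization checks.
    fn calculate(&self, batch: &[i64]) -> Vec<i64> {
        self.projector.project(batch)
    }
}

fn main() {
    let calc = Calculator::new();
    assert_eq!(calc.calculate(&[1, 2, 3]), vec![1, 2, 3]);
}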
