Commit 5724fc5

feat(datafusion): Add sort_by_partition to sort the input partitioned data (#1618)
1 parent a371d82 commit 5724fc5

File tree

  crates/integrations/datafusion/src/physical_plan/mod.rs
  crates/integrations/datafusion/src/physical_plan/sort.rs

2 files changed: +245 -0 lines changed

crates/integrations/datafusion/src/physical_plan/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ pub(crate) mod metadata_scan;
 pub(crate) mod project;
 pub(crate) mod repartition;
 pub(crate) mod scan;
+pub(crate) mod sort;
 pub(crate) mod write;

 pub(crate) const DATA_FILES_COL_NAME: &str = "data_files";
crates/integrations/datafusion/src/physical_plan/sort.rs

Lines changed: 244 additions & 0 deletions

@@ -0,0 +1,244 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Partition-based sorting for Iceberg tables.

use std::sync::Arc;

use datafusion::arrow::compute::SortOptions;
use datafusion::common::Result as DFResult;
use datafusion::error::DataFusionError;
use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::sorts::sort::SortExec;
use iceberg::arrow::PROJECTED_PARTITION_VALUE_COLUMN;

/// Sorts an ExecutionPlan by partition values for Iceberg tables.
///
/// This function takes an input ExecutionPlan that has been extended with partition values
/// (via `project_with_partition`) and returns a SortExec that sorts by the partition column.
/// The partition values are expected to be in a struct column named `PROJECTED_PARTITION_VALUE_COLUMN`.
///
/// For unpartitioned tables or plans without the partition column, returns an error.
///
/// # Arguments
/// * `input` - The input ExecutionPlan with projected partition values
///
/// # Returns
/// * `Ok(Arc<dyn ExecutionPlan>)` - A SortExec that sorts by partition values
/// * `Err` - If the partition column is not found
///
/// TODO remove dead_code mark when integrating with insert_into
#[allow(dead_code)]
pub(crate) fn sort_by_partition(input: Arc<dyn ExecutionPlan>) -> DFResult<Arc<dyn ExecutionPlan>> {
    let schema = input.schema();

    // Find the partition column in the schema
    let (partition_column_index, _partition_field) = schema
        .column_with_name(PROJECTED_PARTITION_VALUE_COLUMN)
        .ok_or_else(|| {
            DataFusionError::Plan(format!(
                "Partition column '{}' not found in schema. Ensure the plan has been extended with partition values using project_with_partition.",
                PROJECTED_PARTITION_VALUE_COLUMN
            ))
        })?;

    // Create a single sort expression for the partition column
    let column_expr = Arc::new(Column::new(
        PROJECTED_PARTITION_VALUE_COLUMN,
        partition_column_index,
    ));

    let sort_expr = PhysicalSortExpr {
        expr: column_expr,
        options: SortOptions::default(), // Ascending, nulls last
    };

    // Create a SortExec with preserve_partitioning=true to ensure the output partitioning
    // is the same as the input partitioning, and the data is sorted within each partition
    let lex_ordering = LexOrdering::new(vec![sort_expr]).ok_or_else(|| {
        DataFusionError::Plan("Failed to create LexOrdering from sort expression".to_string())
    })?;

    let sort_exec = SortExec::new(lex_ordering, input).with_preserve_partitioning(true);

    Ok(Arc::new(sort_exec))
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::array::{Int32Array, RecordBatch, StringArray, StructArray};
    use datafusion::arrow::datatypes::{DataType, Field, Fields, Schema as ArrowSchema};
    use datafusion::datasource::{MemTable, TableProvider};
    use datafusion::prelude::SessionContext;

    use super::*;

    #[tokio::test]
    async fn test_sort_by_partition_basic() {
        // Create a schema with a partition column
        let partition_fields =
            Fields::from(vec![Field::new("id_partition", DataType::Int32, false)]);

        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
            Field::new(
                PROJECTED_PARTITION_VALUE_COLUMN,
                DataType::Struct(partition_fields.clone()),
                false,
            ),
        ]));

        // Create test data with partition values
        let id_array = Arc::new(Int32Array::from(vec![3, 1, 2]));
        let name_array = Arc::new(StringArray::from(vec!["c", "a", "b"]));
        let partition_array = Arc::new(StructArray::from(vec![(
            Arc::new(Field::new("id_partition", DataType::Int32, false)),
            Arc::new(Int32Array::from(vec![3, 1, 2])) as _,
        )]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![id_array, name_array, partition_array])
                .unwrap();

        let ctx = SessionContext::new();
        let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
        let input = mem_table.scan(&ctx.state(), None, &[], None).await.unwrap();

        // Apply sort
        let sorted_plan = sort_by_partition(input).unwrap();

        // Execute and verify
        let result = datafusion::physical_plan::collect(sorted_plan, ctx.task_ctx())
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        let result_batch = &result[0];

        let id_col = result_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        // Verify data is sorted by partition value
        assert_eq!(id_col.value(0), 1);
        assert_eq!(id_col.value(1), 2);
        assert_eq!(id_col.value(2), 3);
    }

    #[tokio::test]
    async fn test_sort_by_partition_missing_column() {
        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let batch = RecordBatch::try_new(schema.clone(), vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        ])
        .unwrap();

        let ctx = SessionContext::new();
        let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
        let input = mem_table.scan(&ctx.state(), None, &[], None).await.unwrap();

        let result = sort_by_partition(input);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Partition column '_partition' not found")
        );
    }

    #[tokio::test]
    async fn test_sort_by_partition_multi_field() {
        // Test with multiple partition fields in the struct
        let partition_fields = Fields::from(vec![
            Field::new("year", DataType::Int32, false),
            Field::new("month", DataType::Int32, false),
        ]);

        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("data", DataType::Utf8, false),
            Field::new(
                PROJECTED_PARTITION_VALUE_COLUMN,
                DataType::Struct(partition_fields.clone()),
                false,
            ),
        ]));

        // Create test data with partition values (year, month)
        let id_array = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
        let data_array = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));

        // Partition values: (2024, 2), (2024, 1), (2023, 12), (2024, 1)
        let year_array = Arc::new(Int32Array::from(vec![2024, 2024, 2023, 2024]));
        let month_array = Arc::new(Int32Array::from(vec![2, 1, 12, 1]));

        let partition_array = Arc::new(StructArray::from(vec![
            (
                Arc::new(Field::new("year", DataType::Int32, false)),
                year_array as _,
            ),
            (
                Arc::new(Field::new("month", DataType::Int32, false)),
                month_array as _,
            ),
        ]));

        let batch =
            RecordBatch::try_new(schema.clone(), vec![id_array, data_array, partition_array])
                .unwrap();

        let ctx = SessionContext::new();
        let mem_table = MemTable::try_new(schema.clone(), vec![vec![batch]]).unwrap();
        let input = mem_table.scan(&ctx.state(), None, &[], None).await.unwrap();

        // Apply sort
        let sorted_plan = sort_by_partition(input).unwrap();

        // Execute and verify
        let result = datafusion::physical_plan::collect(sorted_plan, ctx.task_ctx())
            .await
            .unwrap();

        assert_eq!(result.len(), 1);
        let result_batch = &result[0];

        let id_col = result_batch
            .column(0)
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        // Verify data is sorted by partition value (struct comparison)
        // Expected order: (2023, 12), (2024, 1), (2024, 1), (2024, 2)
        // Which corresponds to ids: 3, 2, 4, 1
        assert_eq!(id_col.value(0), 3);
        assert_eq!(id_col.value(1), 2);
        assert_eq!(id_col.value(2), 4);
        assert_eq!(id_col.value(3), 1);
    }
}
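
The TODO on `sort_by_partition` notes that it is not yet wired into `insert_into`, and its doc comment points to `project_with_partition` as the step that adds the `_partition` struct column. The sketch below is only an illustration of how a caller inside this `physical_plan` module might chain the two steps before handing the plan to a writer; the exact signature of `project_with_partition` and the `Table` argument are assumptions, not taken from this commit.

use std::sync::Arc;

use datafusion::common::Result as DFResult;
use datafusion::physical_plan::ExecutionPlan;
use iceberg::table::Table;

// Illustrative sketch, assumed to live alongside sort.rs so that
// `sort_by_partition` and a `project_with_partition` helper are in scope.
// `project_with_partition`'s signature here is an assumption for illustration.
fn prepare_partitioned_write(
    input: Arc<dyn ExecutionPlan>,
    table: &Table,
) -> DFResult<Arc<dyn ExecutionPlan>> {
    // 1. Extend the plan with the `_partition` struct column derived from the
    //    table's partition spec (assumed helper from the sibling `project` module).
    let with_partition = project_with_partition(input, table)?;

    // 2. Sort rows within each DataFusion partition by those partition values so
    //    a downstream writer sees each Iceberg partition's rows grouped together.
    sort_by_partition(with_partition)
}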
