Skip to content

Commit 7092f64

Browse files
committed
feat: convert partition filters to kernel predicates
Signed-off-by: Robert Pack <[email protected]>
1 parent d96eab3 commit 7092f64

File tree

1 file changed

+314
-4
lines changed

1 file changed

+314
-4
lines changed

crates/core/src/kernel/schema/partitions.rs

Lines changed: 314 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ use std::cmp::Ordering;
33
use std::collections::HashMap;
44
use std::convert::TryFrom;
55

6-
use delta_kernel::expressions::Scalar;
6+
use delta_kernel::expressions::{Expression, JunctionPredicateOp, Predicate, Scalar};
7+
use delta_kernel::schema::StructType;
78
use serde::{Serialize, Serializer};
89

910
use super::{DataType, PrimitiveType};
10-
use crate::errors::DeltaTableError;
11+
use crate::errors::{DeltaResult, DeltaTableError};
1112
use crate::kernel::scalars::ScalarExt;
1213

1314
/// A special value used in Hive to represent the null partition in partitioned tables
@@ -99,7 +100,11 @@ fn compare_typed_value(
99100
/// Partition filters methods for filtering the DeltaTable partitions.
100101
impl PartitionFilter {
101102
/// Indicates if a DeltaTable partition matches with the partition filter by key and value.
102-
pub fn match_partition(&self, partition: &DeltaTablePartition, data_type: &DataType) -> bool {
103+
pub(crate) fn match_partition(
104+
&self,
105+
partition: &DeltaTablePartition,
106+
data_type: &DataType,
107+
) -> bool {
103108
if self.key != partition.key {
104109
return false;
105110
}
@@ -153,7 +158,7 @@ impl PartitionFilter {
153158

154159
/// Indicates if one of the DeltaTable partition among the list
155160
/// matches with the partition filter.
156-
pub fn match_partitions(
161+
pub(crate) fn match_partitions(
157162
&self,
158163
partitions: &[DeltaTablePartition],
159164
partition_col_data_types: &HashMap<&String, &DataType>,
@@ -308,9 +313,65 @@ impl TryFrom<&str> for DeltaTablePartition {
308313
}
309314
}
310315

316+
#[allow(unused)] // TODO: remove once we use this in kernel log replay
317+
pub(crate) fn to_kernel_predicate(
318+
filters: &[PartitionFilter],
319+
table_schema: &StructType,
320+
) -> DeltaResult<Predicate> {
321+
let predicates = filters
322+
.iter()
323+
.map(|filter| filter_to_kernel_predicate(filter, table_schema))
324+
.collect::<DeltaResult<Vec<_>>>()?;
325+
Ok(Predicate::junction(JunctionPredicateOp::And, predicates))
326+
}
327+
328+
fn filter_to_kernel_predicate(
329+
filter: &PartitionFilter,
330+
table_schema: &StructType,
331+
) -> DeltaResult<Predicate> {
332+
let Some(field) = table_schema.field(&filter.key) else {
333+
return Err(DeltaTableError::SchemaMismatch {
334+
msg: format!("Field '{}' is not a root table field.", filter.key),
335+
});
336+
};
337+
let Some(dt) = field.data_type().as_primitive_opt() else {
338+
return Err(DeltaTableError::SchemaMismatch {
339+
msg: format!("Field '{}' is not a primitive type", field.name()),
340+
});
341+
};
342+
343+
let column = Expression::column([field.name()]);
344+
Ok(match &filter.value {
345+
PartitionValue::Equal(raw) => column.eq(dt.parse_scalar(raw)?),
346+
PartitionValue::NotEqual(raw) => column.ne(dt.parse_scalar(raw)?),
347+
PartitionValue::LessThan(raw) => column.lt(dt.parse_scalar(raw)?),
348+
PartitionValue::LessThanOrEqual(raw) => column.le(dt.parse_scalar(raw)?),
349+
PartitionValue::GreaterThan(raw) => column.gt(dt.parse_scalar(raw)?),
350+
PartitionValue::GreaterThanOrEqual(raw) => column.ge(dt.parse_scalar(raw)?),
351+
op @ PartitionValue::In(raw_values) | op @ PartitionValue::NotIn(raw_values) => {
352+
let values = raw_values
353+
.iter()
354+
.map(|v| dt.parse_scalar(v))
355+
.collect::<Result<Vec<_>, _>>()?;
356+
let (expr, operator): (Box<dyn Fn(Scalar) -> Predicate>, _) = match op {
357+
PartitionValue::In(_) => {
358+
(Box::new(|v| column.clone().eq(v)), JunctionPredicateOp::Or)
359+
}
360+
PartitionValue::NotIn(_) => {
361+
(Box::new(|v| column.clone().ne(v)), JunctionPredicateOp::And)
362+
}
363+
_ => unreachable!(),
364+
};
365+
let predicates = values.into_iter().map(expr).collect::<Vec<_>>();
366+
Predicate::junction(operator, predicates)
367+
}
368+
})
369+
}
370+
311371
#[cfg(test)]
312372
mod tests {
313373
use super::*;
374+
use crate::kernel::StructField;
314375
use serde_json::json;
315376

316377
fn check_json_serialize(filter: PartitionFilter, expected_json: &str) {
@@ -487,4 +548,253 @@ mod tests {
487548
assert!(valid_filter_month.match_partitions(&partitions, &partition_data_types),);
488549
assert!(!invalid_filter.match_partitions(&partitions, &partition_data_types),);
489550
}
551+
552+
#[test]
fn test_filter_to_kernel_predicate_equal() {
    // An equality filter on a string column must become `name = "Alice"`.
    let fields = vec![
        StructField::new("name", DataType::Primitive(PrimitiveType::String), true),
        StructField::new("age", DataType::Primitive(PrimitiveType::Integer), true),
    ];
    let table_schema = StructType::new(fields);
    let name_is_alice = PartitionFilter {
        key: "name".to_string(),
        value: PartitionValue::Equal("Alice".to_string()),
    };

    let actual = filter_to_kernel_predicate(&name_is_alice, &table_schema).unwrap();

    assert_eq!(
        actual,
        Expression::column(["name"]).eq(Scalar::String("Alice".into()))
    );
}
568+
569+
#[test]
fn test_filter_to_kernel_predicate_not_equal() {
    // An inequality filter must become `status != "inactive"`.
    let status_field =
        StructField::new("status", DataType::Primitive(PrimitiveType::String), true);
    let table_schema = StructType::new(vec![status_field]);
    let not_inactive = PartitionFilter {
        key: "status".to_string(),
        value: PartitionValue::NotEqual("inactive".to_string()),
    };

    let actual = filter_to_kernel_predicate(&not_inactive, &table_schema).unwrap();

    assert_eq!(
        actual,
        Expression::column(["status"]).ne(Scalar::String("inactive".into()))
    );
}
586+
587+
#[test]
fn test_filter_to_kernel_predicate_comparisons() {
    let schema = StructType::new(vec![
        StructField::new("score", DataType::Primitive(PrimitiveType::Integer), true),
        StructField::new("price", DataType::Primitive(PrimitiveType::Long), true),
    ]);

    // One row per ordering operator: (column, filter value, expected predicate).
    let cases = [
        (
            "score",
            PartitionValue::LessThan("100".to_string()),
            Expression::column(["score"]).lt(Scalar::Integer(100)),
        ),
        (
            "score",
            PartitionValue::LessThanOrEqual("100".to_string()),
            Expression::column(["score"]).le(Scalar::Integer(100)),
        ),
        (
            "price",
            PartitionValue::GreaterThan("50".to_string()),
            Expression::column(["price"]).gt(Scalar::Long(50)),
        ),
        (
            "price",
            PartitionValue::GreaterThanOrEqual("50".to_string()),
            Expression::column(["price"]).ge(Scalar::Long(50)),
        ),
    ];

    for (key, value, expected) in cases {
        let filter = PartitionFilter {
            key: key.to_string(),
            value,
        };
        let predicate = filter_to_kernel_predicate(&filter, &schema).unwrap();
        assert_eq!(predicate, expected);
    }
}
630+
631+
#[test]
fn test_filter_to_kernel_predicate_in_operations() {
    let category_field = StructField::new(
        "category",
        DataType::Primitive(PrimitiveType::String),
        true,
    );
    let schema = StructType::new(vec![category_field]);

    let category = Expression::column(["category"]);
    let raw_values = vec!["books".to_string(), "electronics".to_string()];
    let scalars: Vec<Scalar> = raw_values.iter().cloned().map(Scalar::String).collect();

    // `In` expands to an OR over one equality per listed value.
    let in_filter = PartitionFilter {
        key: "category".to_string(),
        value: PartitionValue::In(raw_values.clone()),
    };
    let actual = filter_to_kernel_predicate(&in_filter, &schema).unwrap();
    let any_of: Vec<_> = scalars
        .iter()
        .cloned()
        .map(|s| category.clone().eq(s))
        .collect();
    assert_eq!(actual, Predicate::junction(JunctionPredicateOp::Or, any_of));

    // `NotIn` expands to an AND over one inequality per listed value.
    let not_in_filter = PartitionFilter {
        key: "category".to_string(),
        value: PartitionValue::NotIn(raw_values),
    };
    let actual = filter_to_kernel_predicate(&not_in_filter, &schema).unwrap();
    let none_of: Vec<_> = scalars
        .into_iter()
        .map(|s| category.clone().ne(s))
        .collect();
    assert_eq!(
        actual,
        Predicate::junction(JunctionPredicateOp::And, none_of)
    );
}
672+
673+
#[test]
fn test_filter_to_kernel_predicate_empty_in_list() {
    let schema = StructType::new(vec![StructField::new(
        "tag",
        DataType::Primitive(PrimitiveType::String),
        true,
    )]);

    // An empty `In` list must yield an OR junction without any operands,
    // not merely "some" successful result.
    let filter = PartitionFilter {
        key: "tag".to_string(),
        value: PartitionValue::In(vec![]),
    };
    let predicate = filter_to_kernel_predicate(&filter, &schema).unwrap();
    assert_eq!(
        predicate,
        Predicate::junction(JunctionPredicateOp::Or, Vec::<Predicate>::new())
    );

    // The mirrored case: an empty `NotIn` list yields an operand-less AND.
    let filter = PartitionFilter {
        key: "tag".to_string(),
        value: PartitionValue::NotIn(vec![]),
    };
    let predicate = filter_to_kernel_predicate(&filter, &schema).unwrap();
    assert_eq!(
        predicate,
        Predicate::junction(JunctionPredicateOp::And, Vec::<Predicate>::new())
    );
}
688+
689+
#[test]
fn test_filter_to_kernel_predicate_field_not_found() {
    let only_field = StructField::new(
        "existing_field",
        DataType::Primitive(PrimitiveType::String),
        true,
    );
    let schema = StructType::new(vec![only_field]);

    // The filter key does not name any field of the schema.
    let unknown_key = PartitionFilter {
        key: "nonexistent_field".to_string(),
        value: PartitionValue::Equal("value".to_string()),
    };

    let outcome = filter_to_kernel_predicate(&unknown_key, &schema);
    assert!(matches!(
        outcome,
        Err(DeltaTableError::SchemaMismatch { .. })
    ));
}
709+
710+
#[test]
fn test_filter_to_kernel_predicate_non_primitive_field() {
    // Build a schema whose only root field is itself a struct.
    let inner_field =
        StructField::new("inner", DataType::Primitive(PrimitiveType::String), true);
    let inner_struct = StructType::new(vec![inner_field]);
    let nested_field =
        StructField::new("nested", DataType::Struct(Box::new(inner_struct)), true);
    let schema = StructType::new(vec![nested_field]);

    let on_struct_column = PartitionFilter {
        key: "nested".to_string(),
        value: PartitionValue::Equal("value".to_string()),
    };

    // Filtering on a struct-typed column is rejected as a schema mismatch.
    let outcome = filter_to_kernel_predicate(&on_struct_column, &schema);
    assert!(matches!(
        outcome,
        Err(DeltaTableError::SchemaMismatch { .. })
    ));
}
735+
736+
#[test]
fn test_filter_to_kernel_predicate_different_data_types() {
    let schema = StructType::new(vec![
        StructField::new(
            "bool_field",
            DataType::Primitive(PrimitiveType::Boolean),
            true,
        ),
        StructField::new("date_field", DataType::Primitive(PrimitiveType::Date), true),
        StructField::new(
            "timestamp_field",
            DataType::Primitive(PrimitiveType::Timestamp),
            true,
        ),
        StructField::new(
            "double_field",
            DataType::Primitive(PrimitiveType::Double),
            true,
        ),
        StructField::new(
            "float_field",
            DataType::Primitive(PrimitiveType::Float),
            true,
        ),
    ]);

    // Test boolean field
    let filter = PartitionFilter {
        key: "bool_field".to_string(),
        value: PartitionValue::Equal("true".to_string()),
    };
    assert!(filter_to_kernel_predicate(&filter, &schema).is_ok());

    // Test date field
    let filter = PartitionFilter {
        key: "date_field".to_string(),
        value: PartitionValue::GreaterThan("2023-01-01".to_string()),
    };
    assert!(filter_to_kernel_predicate(&filter, &schema).is_ok());

    // Test timestamp field (declared in the schema, so exercise it as well)
    let filter = PartitionFilter {
        key: "timestamp_field".to_string(),
        value: PartitionValue::LessThanOrEqual("2023-01-01 12:00:00".to_string()),
    };
    assert!(filter_to_kernel_predicate(&filter, &schema).is_ok());

    // Test double field (declared in the schema, so exercise it as well)
    let filter = PartitionFilter {
        key: "double_field".to_string(),
        value: PartitionValue::GreaterThanOrEqual("2.5".to_string()),
    };
    assert!(filter_to_kernel_predicate(&filter, &schema).is_ok());

    // Test float field
    let filter = PartitionFilter {
        key: "float_field".to_string(),
        value: PartitionValue::LessThan("3.14".to_string()),
    };
    assert!(filter_to_kernel_predicate(&filter, &schema).is_ok());
}
783+
784+
#[test]
fn test_filter_to_kernel_predicate_invalid_scalar_value() {
    let number_field =
        StructField::new("number", DataType::Primitive(PrimitiveType::Integer), true);
    let schema = StructType::new(vec![number_field]);

    // The raw value cannot be parsed as an integer, so translation must fail.
    let unparsable = PartitionFilter {
        key: "number".to_string(),
        value: PartitionValue::Equal("not_a_number".to_string()),
    };

    assert!(filter_to_kernel_predicate(&unparsable, &schema).is_err());
}
490800
}

0 commit comments

Comments
 (0)