Skip to content

Commit f157155

Browse files
authored
feat(rust/sedona-expr): Use Covers filter for ST_Equals for more Geoparquet pruning (#216)
1 parent bb526b7 commit f157155

File tree

1 file changed

+62
-10
lines changed

1 file changed

+62
-10
lines changed

rust/sedona-expr/src/spatial_filter.rs

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ impl SpatialFilter {
177177
let args = parse_args(raw_args);
178178
let fun_name = scalar_fun.fun().name();
179179
match fun_name {
180-
"st_intersects" | "st_equals" | "st_touches" | "st_crosses" | "st_overlaps" => {
180+
"st_intersects" | "st_touches" | "st_crosses" | "st_overlaps" => {
181181
if args.len() != 2 {
182182
return sedona_internal_err!("unexpected argument count in filter evaluation");
183183
}
@@ -199,6 +199,28 @@ impl SpatialFilter {
199199
_ => Ok(Some(Self::Unknown)),
200200
}
201201
}
202+
"st_equals" => {
203+
if args.len() != 2 {
204+
return sedona_internal_err!("unexpected argument count in filter evaluation");
205+
}
206+
207+
match (&args[0], &args[1]) {
208+
(ArgRef::Col(column), ArgRef::Lit(literal))
209+
| (ArgRef::Lit(literal), ArgRef::Col(column)) => {
210+
if !is_prunable_geospatial_literal(literal) {
211+
return Ok(Some(Self::Unknown));
212+
}
213+
match literal_bounds(literal) {
214+
Ok(literal_bounds) => {
215+
Ok(Some(Self::Covers(column.clone(), literal_bounds)))
216+
}
217+
Err(e) => Err(DataFusionError::External(Box::new(e))),
218+
}
219+
}
220+
// Not between a literal and a column
221+
_ => Ok(Some(Self::Unknown)),
222+
}
223+
}
202224
"st_within" | "st_covered_by" | "st_coveredby" => {
203225
if args.len() != 2 {
204226
return sedona_internal_err!("unexpected argument count in filter evaluation");
@@ -575,15 +597,8 @@ mod test {
575597
}
576598

577599
#[rstest]
578-
fn predicate_from_expr_commutative_functions(
579-
#[values(
580-
"st_intersects",
581-
"st_equals",
582-
"st_touches",
583-
"st_crosses",
584-
"st_overlaps"
585-
)]
586-
func_name: &str,
600+
fn predicate_from_expr_commutative_intersects_functions(
601+
#[values("st_intersects", "st_touches", "st_crosses", "st_overlaps")] func_name: &str,
587602
) {
588603
let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry", 0));
589604
let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
@@ -620,6 +635,43 @@ mod test {
620635
);
621636
}
622637

638+
#[rstest]
639+
fn predicate_from_expr_equals_function(#[values("st_equals")] func_name: &str) {
640+
let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry", 0));
641+
let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
642+
let literal: Arc<dyn PhysicalExpr> = Arc::new(Literal::new_with_metadata(
643+
create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"), &WKB_GEOMETRY),
644+
Some(storage_field.metadata().into()),
645+
));
646+
647+
// Test functions that should result in Covers filter
648+
let func = create_dummy_spatial_function(func_name, 2);
649+
let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
650+
func_name,
651+
Arc::new(func.clone()),
652+
vec![column.clone(), literal.clone()],
653+
Arc::new(Field::new("", DataType::Boolean, true)),
654+
));
655+
let predicate = SpatialFilter::try_from_expr(&expr).unwrap();
656+
assert!(
657+
matches!(predicate, SpatialFilter::Covers(_, _)),
658+
"Function {func_name} should produce Covers filter"
659+
);
660+
661+
// Test reversed argument order
662+
let expr_reversed: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
663+
func_name,
664+
Arc::new(func),
665+
vec![literal.clone(), column.clone()],
666+
Arc::new(Field::new("", DataType::Boolean, true)),
667+
));
668+
let predicate_reversed = SpatialFilter::try_from_expr(&expr_reversed).unwrap();
669+
assert!(
670+
matches!(predicate_reversed, SpatialFilter::Covers(_, _)),
671+
"Function {func_name} with reversed args should produce Covers filter"
672+
);
673+
}
674+
623675
#[rstest]
624676
fn predicate_from_expr_within_covered_by_functions(
625677
#[values("st_within", "st_covered_by", "st_coveredby")] func_name: &str,

0 commit comments

Comments
 (0)