Skip to content

Commit 9efa952

Browse files
authored
ST_Distance and ST_DWithin based on georust/geo (#73)
1 parent aa4f80a commit 9efa952

File tree

7 files changed

+376
-51
lines changed

7 files changed

+376
-51
lines changed

Cargo.lock

Lines changed: 31 additions & 31 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

c/sedona-geos/src/st_dwithin.rs

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::sync::Arc;
1818

1919
use arrow_array::builder::BooleanBuilder;
2020
use arrow_schema::DataType;
21-
use datafusion_common::{error::Result, DataFusionError};
21+
use datafusion_common::{cast::as_float64_array, error::Result, DataFusionError};
2222
use datafusion_expr::ColumnarValue;
2323
use geos::Geom;
2424
use sedona_expr::scalar_udf::{ScalarKernelRef, SedonaScalarKernel};
@@ -53,26 +53,14 @@ impl SedonaScalarKernel for STDWithin {
5353
arg_types: &[SedonaType],
5454
args: &[ColumnarValue],
5555
) -> Result<ColumnarValue> {
56-
// Extract the constant scalar value before looping over the input geometries
57-
let distance: Option<f64>;
5856
let arg2 = args[2].cast_to(&DataType::Float64, None)?;
59-
if let ColumnarValue::Scalar(scalar_arg) = &arg2 {
60-
if scalar_arg.is_null() {
61-
distance = None;
62-
} else {
63-
distance = Some(f64::try_from(scalar_arg.clone())?);
64-
}
65-
} else {
66-
return Err(DataFusionError::Execution(format!(
67-
"Invalid distance: {:?}",
68-
args[2]
69-
)));
70-
}
71-
7257
let executor = GeosExecutor::new(arg_types, args);
58+
let arg2_array = arg2.to_array(executor.num_iterations())?;
59+
let arg2_f64_array = as_float64_array(&arg2_array)?;
60+
let mut arg2_iter = arg2_f64_array.iter();
7361
let mut builder = BooleanBuilder::with_capacity(executor.num_iterations());
7462
executor.execute_wkb_wkb_void(|lhs, rhs| {
75-
match (lhs, rhs, distance) {
63+
match (lhs, rhs, arg2_iter.next().unwrap()) {
7664
(Some(lhs), Some(rhs), Some(distance)) => {
7765
builder.append_value(invoke_scalar(lhs, rhs, distance)?);
7866
}
@@ -151,9 +139,16 @@ mod tests {
151139
let expected: ArrayRef = arrow_array!(Boolean, [Some(true), Some(false), None, Some(true)]);
152140
assert_array_equal(
153141
&tester
154-
.invoke_array_array_scalar(arg1, arg2, distance)
142+
.invoke_array_array_scalar(Arc::clone(&arg1), Arc::clone(&arg2), distance)
155143
.unwrap(),
156144
&expected,
157145
);
146+
147+
let distance = arrow_array!(Int32, [Some(1), Some(1), Some(1), Some(1)]);
148+
let expected: ArrayRef = arrow_array!(Boolean, [Some(true), Some(false), None, Some(true)]);
149+
assert_array_equal(
150+
&tester.invoke_arrays(vec![arg1, arg2, distance]).unwrap(),
151+
&expected,
152+
);
158153
}
159154
}

rust/sedona-geo/benches/geo-functions.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,30 @@ fn criterion_benchmark(c: &mut Criterion) {
4141
"st_intersects",
4242
ArrayScalar(Point, Polygon(500)),
4343
);
44+
45+
benchmark::scalar(c, &f, "geo", "st_distance", ArrayScalar(Point, Polygon(10)));
46+
benchmark::scalar(
47+
c,
48+
&f,
49+
"geo",
50+
"st_distance",
51+
ArrayScalar(Point, Polygon(500)),
52+
);
53+
54+
benchmark::scalar(
55+
c,
56+
&f,
57+
"geo",
58+
"st_dwithin",
59+
ArrayArrayScalar(Polygon(10), Polygon(10), Float64(1.0, 2.0)),
60+
);
61+
benchmark::scalar(
62+
c,
63+
&f,
64+
"geo",
65+
"st_dwithin",
66+
ArrayArrayScalar(Polygon(10), Polygon(500), Float64(1.0, 2.0)),
67+
);
4468
}
4569

4670
fn criterion_benchmark_aggr(c: &mut Criterion) {

rust/sedona-geo/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ pub mod centroid;
1818
pub mod register;
1919
mod st_area;
2020
mod st_centroid;
21+
mod st_distance;
22+
mod st_dwithin;
2123
mod st_intersection_aggr;
2224
mod st_intersects;
2325
mod st_length;

rust/sedona-geo/src/register.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ use crate::st_intersection_aggr::st_intersection_aggr_impl;
2121
use crate::st_line_interpolate_point::st_line_interpolate_point_impl;
2222
use crate::st_union_aggr::st_union_aggr_impl;
2323
use crate::{
24-
st_area::st_area_impl, st_centroid::st_centroid_impl, st_intersects::st_intersects_impl,
25-
st_length::st_length_impl,
24+
st_area::st_area_impl, st_centroid::st_centroid_impl, st_distance::st_distance_impl,
25+
st_dwithin::st_dwithin_impl, st_intersects::st_intersects_impl, st_length::st_length_impl,
2626
};
2727

2828
pub fn scalar_kernels() -> Vec<(&'static str, ScalarKernelRef)> {
2929
vec![
3030
("st_intersects", st_intersects_impl()),
3131
("st_area", st_area_impl()),
3232
("st_centroid", st_centroid_impl()),
33+
("st_distance", st_distance_impl()),
34+
("st_dwithin", st_dwithin_impl()),
3335
("st_length", st_length_impl()),
3436
("st_lineinterpolatepoint", st_line_interpolate_point_impl()),
3537
]

0 commit comments

Comments
 (0)