Skip to content

Commit 64359d7

Browse files
test: add Rust and Python tests for ST_Relate
1 parent 156993f commit 64359d7

File tree

3 files changed

+96
-3
lines changed

3 files changed

+96
-3
lines changed

c/sedona-geos/src/st_relate.rs

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
// KIND, either express or implied. See the License for the
1515
// specific language governing permissions and limitations
1616
// under the License.
17+
use std::sync::Arc;
1718

1819
use arrow_array::builder::StringBuilder;
20+
use arrow_schema::DataType;
1921
use datafusion_common::error::Result;
2022
use datafusion_common::DataFusionError;
2123
use datafusion_expr::ColumnarValue;
@@ -25,7 +27,6 @@ use sedona_expr::{
2527
scalar_udf::{ScalarKernelRef, SedonaScalarKernel},
2628
};
2729
use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
28-
use std::sync::Arc;
2930

3031
use crate::executor::GeosExecutor;
3132

@@ -41,7 +42,7 @@ impl SedonaScalarKernel for STRelate {
4142
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
4243
let matcher = ArgMatcher::new(
4344
vec![ArgMatcher::is_geometry(), ArgMatcher::is_geometry()],
44-
SedonaType::Arrow(arrow_schema::DataType::Utf8),
45+
SedonaType::Arrow(DataType::Utf8),
4546
);
4647

4748
matcher.match_args(args)
@@ -75,3 +76,59 @@ impl SedonaScalarKernel for STRelate {
7576
executor.finish(Arc::new(builder.finish()))
7677
}
7778
}
79+
80+
#[cfg(test)]
81+
mod tests {
82+
use arrow_array::{create_array as arrow_array, ArrayRef};
83+
use datafusion_common::ScalarValue;
84+
use rstest::rstest;
85+
use sedona_expr::scalar_udf::SedonaScalarUDF;
86+
use sedona_schema::datatypes::{WKB_GEOMETRY, WKB_VIEW_GEOMETRY};
87+
use sedona_testing::compare::assert_array_equal;
88+
use sedona_testing::create::create_array;
89+
use sedona_testing::testers::ScalarUdfTester;
90+
91+
use super::*;
92+
93+
#[rstest]
94+
fn udf(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) {
95+
let udf = SedonaScalarUDF::from_impl("st_relate", st_relate_impl());
96+
let tester = ScalarUdfTester::new(udf.into(), vec![sedona_type.clone(), sedona_type]);
97+
tester.assert_return_type(DataType::Utf8);
98+
99+
// Two disjoint points — DE-9IM should be "FF0FFF0F2"
100+
let result = tester
101+
.invoke_scalar_scalar("POINT (0 0)", "POINT (1 1)")
102+
.unwrap();
103+
tester.assert_scalar_result_equals(result, "FF0FFF0F2");
104+
105+
// NULL inputs should return NULL
106+
let result = tester
107+
.invoke_scalar_scalar(ScalarValue::Null, ScalarValue::Null)
108+
.unwrap();
109+
assert!(result.is_null());
110+
111+
// Array inputs
112+
let lhs = create_array(
113+
&[
114+
Some("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"),
115+
Some("POINT (0.5 0.5)"),
116+
None,
117+
],
118+
&WKB_GEOMETRY,
119+
);
120+
let rhs = create_array(
121+
&[
122+
Some("POINT (0.5 0.5)"),
123+
Some("POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))"),
124+
Some("POINT (0 0)"),
125+
],
126+
&WKB_GEOMETRY,
127+
);
128+
129+
// polygon contains point → "0F2FF1FF2"
130+
// point within polygon → "0F2FF1FF2" (same matrix, reversed)
131+
let expected: ArrayRef = arrow_array!(Utf8, [Some("0F2FF1FF2"), Some("0F2FF1FF2"), None]);
132+
assert_array_equal(&tester.invoke_array_array(lhs, rhs).unwrap(), &expected);
133+
}
134+
}

python/sedonadb/tests/functions/test_predicates.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,39 @@ def test_st_overlaps(eng, geom1, geom2, expected):
442442
f"SELECT ST_Overlaps({geom_or_null(geom1)}, {geom_or_null(geom2)})",
443443
expected,
444444
)
445+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
446+
@pytest.mark.parametrize(
447+
("geom1", "geom2", "expected"),
448+
[
449+
(None, None, None),
450+
("POINT (0 0)", None, None),
451+
(None, "POINT (0 0)", None),
452+
# Two disjoint points
453+
("POINT (0 0)", "POINT (1 1)", "FF0FFF0F2"),
454+
# Identical points
455+
("POINT (0 0)", "POINT (0 0)", "0FFFFFFF2"),
456+
# Point on boundary of polygon
457+
("POINT (0 0)", "POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))", "FF2F01212"),
458+
# Point inside polygon
459+
("POINT (0.5 0.5)", "POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))", "0F2FF1FF2"),
460+
# Two disjoint polygons
461+
(
462+
"POLYGON ((0 0, 1 0, 1 1, 0 1, 0 0))",
463+
"POLYGON ((5 5, 6 5, 6 6, 5 6, 5 5))",
464+
"FF2FF1212",
465+
),
466+
# Overlapping polygons
467+
(
468+
"POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))",
469+
"POLYGON ((1 1, 3 1, 3 3, 1 3, 1 1))",
470+
"212101212",
471+
),
472+
],
473+
)
474+
def test_st_relate(eng, geom1, geom2, expected):
475+
eng = eng.create_or_skip()
476+
eng.assert_query_result(
477+
f"SELECT ST_Relate({geom_or_null(geom1)}, {geom_or_null(geom2)})",
478+
expected,
479+
)
480+

rust/sedona-functions/src/register.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ pub fn default_function_set() -> FunctionSet {
7676
crate::st_pointzm::st_pointm_udf,
7777
crate::st_pointzm::st_pointz_udf,
7878
crate::st_pointzm::st_pointzm_udf,
79+
crate::st_relate::st_relate_udf,
7980
crate::st_reverse::st_reverse_udf,
8081
crate::st_rotate::st_rotate_udf,
8182
crate::st_rotate::st_rotate_x_udf,
@@ -105,7 +106,6 @@ pub fn default_function_set() -> FunctionSet {
105106
crate::st_xyzm::st_y_udf,
106107
crate::st_xyzm::st_z_udf,
107108
crate::st_zmflag::st_zmflag_udf,
108-
crate::st_relate::st_relate_udf,
109109
);
110110

111111
register_aggregate_udfs!(

0 commit comments

Comments
 (0)