1414// KIND, either express or implied. See the License for the
1515// specific language governing permissions and limitations
1616// under the License.
17+ use std:: sync:: Arc ;
1718
1819use arrow_array:: builder:: StringBuilder ;
20+ use arrow_schema:: DataType ;
1921use datafusion_common:: error:: Result ;
2022use datafusion_common:: DataFusionError ;
2123use datafusion_expr:: ColumnarValue ;
@@ -25,7 +27,6 @@ use sedona_expr::{
2527 scalar_udf:: { ScalarKernelRef , SedonaScalarKernel } ,
2628} ;
2729use sedona_schema:: { datatypes:: SedonaType , matchers:: ArgMatcher } ;
28- use std:: sync:: Arc ;
2930
3031use crate :: executor:: GeosExecutor ;
3132
@@ -41,7 +42,7 @@ impl SedonaScalarKernel for STRelate {
4142 fn return_type ( & self , args : & [ SedonaType ] ) -> Result < Option < SedonaType > > {
4243 let matcher = ArgMatcher :: new (
4344 vec ! [ ArgMatcher :: is_geometry( ) , ArgMatcher :: is_geometry( ) ] ,
44- SedonaType :: Arrow ( arrow_schema :: DataType :: Utf8 ) ,
45+ SedonaType :: Arrow ( DataType :: Utf8 ) ,
4546 ) ;
4647
4748 matcher. match_args ( args)
@@ -75,3 +76,59 @@ impl SedonaScalarKernel for STRelate {
7576 executor. finish ( Arc :: new ( builder. finish ( ) ) )
7677 }
7778}
79+
80+ #[ cfg( test) ]
81+ mod tests {
82+ use arrow_array:: { create_array as arrow_array, ArrayRef } ;
83+ use datafusion_common:: ScalarValue ;
84+ use rstest:: rstest;
85+ use sedona_expr:: scalar_udf:: SedonaScalarUDF ;
86+ use sedona_schema:: datatypes:: { WKB_GEOMETRY , WKB_VIEW_GEOMETRY } ;
87+ use sedona_testing:: compare:: assert_array_equal;
88+ use sedona_testing:: create:: create_array;
89+ use sedona_testing:: testers:: ScalarUdfTester ;
90+
91+ use super :: * ;
92+
93+ #[ rstest]
94+ fn udf ( #[ values( WKB_GEOMETRY , WKB_VIEW_GEOMETRY ) ] sedona_type : SedonaType ) {
95+ let udf = SedonaScalarUDF :: from_impl ( "st_relate" , st_relate_impl ( ) ) ;
96+ let tester = ScalarUdfTester :: new ( udf. into ( ) , vec ! [ sedona_type. clone( ) , sedona_type] ) ;
97+ tester. assert_return_type ( DataType :: Utf8 ) ;
98+
99+ // Two disjoint points — DE-9IM should be "FF0FFF0F2"
100+ let result = tester
101+ . invoke_scalar_scalar ( "POINT (0 0)" , "POINT (1 1)" )
102+ . unwrap ( ) ;
103+ tester. assert_scalar_result_equals ( result, "FF0FFF0F2" ) ;
104+
105+ // NULL inputs should return NULL
106+ let result = tester
107+ . invoke_scalar_scalar ( ScalarValue :: Null , ScalarValue :: Null )
108+ . unwrap ( ) ;
109+ assert ! ( result. is_null( ) ) ;
110+
111+ // Array inputs
112+ let lhs = create_array (
113+ & [
114+ Some ( "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))" ) ,
115+ Some ( "POINT (0.5 0.5)" ) ,
116+ None ,
117+ ] ,
118+ & WKB_GEOMETRY ,
119+ ) ;
120+ let rhs = create_array (
121+ & [
122+ Some ( "POINT (0.5 0.5)" ) ,
123+ Some ( "POLYGON ((0 0, 0 1, 1 1, 1 0, 0 0))" ) ,
124+ Some ( "POINT (0 0)" ) ,
125+ ] ,
126+ & WKB_GEOMETRY ,
127+ ) ;
128+
129+ // polygon contains point → "0F2FF1FF2"
130+ // point within polygon → "0F2FF1FF2" (same matrix, reversed)
131+ let expected: ArrayRef = arrow_array ! ( Utf8 , [ Some ( "0F2FF1FF2" ) , Some ( "0F2FF1FF2" ) , None ] ) ;
132+ assert_array_equal ( & tester. invoke_array_array ( lhs, rhs) . unwrap ( ) , & expected) ;
133+ }
134+ }
0 commit comments