@@ -26,6 +26,7 @@ use sedona_geometry::transform::{transform, CachingCrsEngine, CrsEngine, CrsTran
2626use sedona_geometry:: wkb_factory:: WKB_MIN_PROBABLE_BYTES ;
2727use sedona_schema:: crs:: deserialize_crs;
2828use sedona_schema:: datatypes:: { Edges , SedonaType } ;
29+ use sedona_schema:: matchers:: ArgMatcher ;
2930use std:: cell:: OnceCell ;
3031use std:: rc:: Rc ;
3132use std:: sync:: { Arc , RwLock } ;
@@ -135,7 +136,9 @@ fn define_arg_indexes(arg_types: &[SedonaType], indexes: &mut TransformArgIndexe
135136 indexes. first_crs = 1 ;
136137
137138 for ( i, arg_type) in arg_types. iter ( ) . enumerate ( ) . skip ( 2 ) {
138- if * arg_type == SedonaType :: Arrow ( DataType :: Utf8 ) {
139+ if ArgMatcher :: is_numeric ( ) . match_type ( arg_type)
140+ || ArgMatcher :: is_string ( ) . match_type ( arg_type)
141+ {
139142 indexes. second_crs = Some ( i) ;
140143 } else if * arg_type == SedonaType :: Arrow ( DataType :: Boolean ) {
141144 indexes. lenient = Some ( i) ;
@@ -154,17 +157,41 @@ impl SedonaScalarKernel for STTransform {
154157 arg_types : & [ SedonaType ] ,
155158 scalar_args : & [ Option < & ScalarValue > ] ,
156159 ) -> Result < Option < SedonaType > > {
160+ let matcher = ArgMatcher :: new (
161+ vec ! [
162+ ArgMatcher :: is_geometry_or_geography( ) ,
163+ ArgMatcher :: or( vec![ ArgMatcher :: is_numeric( ) , ArgMatcher :: is_string( ) ] ) ,
164+ ArgMatcher :: optional( ArgMatcher :: or( vec![
165+ ArgMatcher :: is_numeric( ) ,
166+ ArgMatcher :: is_string( ) ,
167+ ] ) ) ,
168+ ArgMatcher :: optional( ArgMatcher :: is_boolean( ) ) ,
169+ ] ,
170+ SedonaType :: Wkb ( Edges :: Planar , None ) ,
171+ ) ;
172+
173+ if !matcher. matches ( arg_types) {
174+ return Ok ( None ) ;
175+ }
176+
157177 let mut indexes = TransformArgIndexes :: new ( ) ;
158178 define_arg_indexes ( arg_types, & mut indexes) ;
159179
160- let to_crs_opt = if let Some ( second_crs_index) = indexes. second_crs {
180+ let scalar_arg_opt = if let Some ( second_crs_index) = indexes. second_crs {
161181 scalar_args. get ( second_crs_index) . unwrap ( )
162182 } else {
163183 scalar_args. get ( indexes. first_crs ) . unwrap ( )
164184 } ;
165185
166- match to_crs_opt {
167- Some ( ScalarValue :: Utf8 ( Some ( to_crs) ) ) => {
186+ let crs_str_opt = if let Some ( scalar_crs) = scalar_arg_opt {
187+ to_crs_str ( scalar_crs)
188+ } else {
189+ None
190+ } ;
191+
192+ // If there is no CRS argument, we cannot determine the return type.
193+ match crs_str_opt {
194+ Some ( to_crs) => {
168195 let val = serde_json:: Value :: String ( to_crs. to_string ( ) ) ;
169196 let crs = deserialize_crs ( & val) ?;
170197 Ok ( Some ( SedonaType :: Wkb ( Edges :: Planar , crs) ) )
@@ -187,16 +214,18 @@ impl SedonaScalarKernel for STTransform {
187214 let mut indexes = TransformArgIndexes :: new ( ) ;
188215 define_arg_indexes ( arg_types, & mut indexes) ;
189216
190- let first_crs = get_scalar_str ( args, indexes. first_crs ) . ok_or_else ( || {
191- DataFusionError :: Execution ( "First argument must be a scalar string" . into ( ) )
217+ let first_crs = get_crs_str ( args, indexes. first_crs ) . ok_or_else ( || {
218+ DataFusionError :: Execution (
219+ "First CRS argument must be a string or numeric scalar" . to_string ( ) ,
220+ )
192221 } ) ?;
193222
194223 let lenient = indexes
195224 . lenient
196225 . is_some_and ( |i| get_scalar_bool ( args, i) . unwrap_or ( false ) ) ;
197226
198227 let second_crs = if let Some ( second_crs_index) = indexes. second_crs {
199- get_scalar_str ( args, second_crs_index)
228+ get_crs_str ( args, second_crs_index)
200229 } else {
201230 None
202231 } ;
@@ -270,12 +299,23 @@ fn parse_source_crs(source_type: &SedonaType) -> Result<Option<String>> {
270299 }
271300}
272301
273- fn get_scalar_str ( args : & [ ColumnarValue ] , index : usize ) -> Option < String > {
274- if let Some ( ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( opt_str) ) ) = args. get ( index) {
275- opt_str. clone ( )
276- } else {
277- None
302+ fn to_crs_str ( scalar_arg : & ScalarValue ) -> Option < String > {
303+ if let Ok ( ScalarValue :: Utf8 ( Some ( crs) ) ) = scalar_arg. cast_to ( & DataType :: Utf8 ) {
304+ if crs. chars ( ) . all ( |c| c. is_ascii_digit ( ) ) {
305+ return Some ( format ! ( "EPSG:{crs}" ) ) ;
306+ } else {
307+ return Some ( crs) ;
308+ }
309+ }
310+
311+ None
312+ }
313+
314+ fn get_crs_str ( args : & [ ColumnarValue ] , index : usize ) -> Option < String > {
315+ if let ColumnarValue :: Scalar ( scalar_crs) = & args[ index] {
316+ return to_crs_str ( scalar_crs) ;
278317 }
318+ None
279319}
280320
281321fn get_scalar_bool ( args : & [ ColumnarValue ] , index : usize ) -> Option < bool > {
@@ -303,6 +343,88 @@ mod tests {
303343 const NAD83ZONE6PROJ : & str = "EPSG:2230" ;
304344 const WGS84 : & str = "EPSG:4326" ;
305345
346+ #[ rstest]
347+ fn invalid_arg_checks ( ) {
348+ let udf: SedonaScalarUDF =
349+ SedonaScalarUDF :: from_kernel ( "st_transform" , st_transform_impl ( ) ) ;
350+
351+ // No args
352+ let result = udf. return_field_from_args ( ReturnFieldArgs {
353+ arg_fields : & [ ] ,
354+ scalar_arguments : & [ ] ,
355+ } ) ;
356+ assert ! (
357+ result. is_err( )
358+ && result
359+ . unwrap_err( )
360+ . to_string( )
361+ . contains( "No kernel matching arguments" )
362+ ) ;
363+
364+ // Too many args
365+ let arg_types = [
366+ WKB_GEOMETRY ,
367+ SedonaType :: Arrow ( DataType :: Utf8 ) ,
368+ SedonaType :: Arrow ( DataType :: Utf8 ) ,
369+ SedonaType :: Arrow ( DataType :: Boolean ) ,
370+ SedonaType :: Arrow ( DataType :: Int32 ) ,
371+ ] ;
372+ let arg_fields: Vec < Arc < Field > > = arg_types
373+ . iter ( )
374+ . map ( |arg_type| Arc :: new ( arg_type. to_storage_field ( "" , true ) . unwrap ( ) ) )
375+ . collect ( ) ;
376+ let result = udf. return_field_from_args ( ReturnFieldArgs {
377+ arg_fields : & arg_fields,
378+ scalar_arguments : & [ None , None , None , None , None ] ,
379+ } ) ;
380+ assert ! (
381+ result. is_err( )
382+ && result
383+ . unwrap_err( )
384+ . to_string( )
385+ . contains( "No kernel matching arguments" )
386+ ) ;
387+
388+ // First arg not geometry
389+ let arg_types = [
390+ SedonaType :: Arrow ( DataType :: Utf8 ) ,
391+ SedonaType :: Arrow ( DataType :: Utf8 ) ,
392+ ] ;
393+ let arg_fields: Vec < Arc < Field > > = arg_types
394+ . iter ( )
395+ . map ( |arg_type| Arc :: new ( arg_type. to_storage_field ( "" , true ) . unwrap ( ) ) )
396+ . collect ( ) ;
397+ let result = udf. return_field_from_args ( ReturnFieldArgs {
398+ arg_fields : & arg_fields,
399+ scalar_arguments : & [ None , None ] ,
400+ } ) ;
401+ assert ! (
402+ result. is_err( )
403+ && result
404+ . unwrap_err( )
405+ . to_string( )
406+ . contains( "No kernel matching arguments" )
407+ ) ;
408+
409+ // Second arg not string or numeric
410+ let arg_types = [ WKB_GEOMETRY , SedonaType :: Arrow ( DataType :: Boolean ) ] ;
411+ let arg_fields: Vec < Arc < Field > > = arg_types
412+ . iter ( )
413+ . map ( |arg_type| Arc :: new ( arg_type. to_storage_field ( "" , true ) . unwrap ( ) ) )
414+ . collect ( ) ;
415+ let result = udf. return_field_from_args ( ReturnFieldArgs {
416+ arg_fields : & arg_fields,
417+ scalar_arguments : & [ None , None ] ,
418+ } ) ;
419+ assert ! (
420+ result. is_err( )
421+ && result
422+ . unwrap_err( )
423+ . to_string( )
424+ . contains( "No kernel matching arguments" )
425+ ) ;
426+ }
427+
306428 #[ rstest]
307429 fn test_invoke_batch_with_geo_crs ( ) {
308430 // From-CRS pulled from sedona type
@@ -329,6 +451,32 @@ mod tests {
329451 ) ;
330452 }
331453
454+ #[ rstest]
455+ fn test_invoke_with_srids ( ) {
456+ // Use an integer SRID for the to CRS
457+ let arg_types = [
458+ SedonaType :: Wkb ( Edges :: Planar , lnglat ( ) ) ,
459+ SedonaType :: Arrow ( DataType :: UInt32 ) ,
460+ ] ;
461+
462+ let wkb = create_array ( & [ None , Some ( "POINT (79.3871 43.6426)" ) ] , & arg_types[ 0 ] ) ;
463+
464+ let scalar_args = vec ! [ ScalarValue :: UInt32 ( Some ( 2230 ) ) ] ;
465+
466+ let expected = create_array_value (
467+ & [ None , Some ( "POINT (-21508577.363421552 34067918.06097863)" ) ] ,
468+ & SedonaType :: Wkb ( Edges :: Planar , get_crs ( NAD83ZONE6PROJ ) ) ,
469+ ) ;
470+
471+ let ( result_type, result_col) =
472+ invoke_udf_test ( wkb, scalar_args, arg_types. to_vec ( ) ) . unwrap ( ) ;
473+ assert_value_equal ( & result_col, & expected) ;
474+ assert_eq ! (
475+ result_type,
476+ SedonaType :: Wkb ( Edges :: Planar , get_crs( NAD83ZONE6PROJ ) )
477+ ) ;
478+ }
479+
332480 #[ rstest]
333481 fn test_invoke_batch_with_lenient ( ) {
334482 let arg_types = [
@@ -372,7 +520,7 @@ mod tests {
372520 }
373521
374522 #[ rstest]
375- fn test_invoke_batch_with_string_source ( ) {
523+ fn test_invoke_batch_with_source_arg ( ) {
376524 let arg_types = [
377525 WKB_GEOMETRY ,
378526 SedonaType :: Arrow ( DataType :: Utf8 ) ,
@@ -392,6 +540,26 @@ mod tests {
392540 & SedonaType :: Wkb ( Edges :: Planar , Some ( get_crs ( NAD83ZONE6PROJ ) . unwrap ( ) ) ) ,
393541 ) ;
394542
543+ let ( result_type, result_col) =
544+ invoke_udf_test ( wkb. clone ( ) , scalar_args, arg_types. to_vec ( ) ) . unwrap ( ) ;
545+ assert_value_equal ( & result_col, & expected) ;
546+ assert_eq ! (
547+ result_type,
548+ SedonaType :: Wkb ( Edges :: Planar , Some ( get_crs( NAD83ZONE6PROJ ) . unwrap( ) ) )
549+ ) ;
550+
551+ // Test with integer SRIDs
552+ let arg_types = [
553+ WKB_GEOMETRY ,
554+ SedonaType :: Arrow ( DataType :: Int32 ) ,
555+ SedonaType :: Arrow ( DataType :: Int32 ) ,
556+ ] ;
557+
558+ let scalar_args = vec ! [
559+ ScalarValue :: Int32 ( Some ( 4326 ) ) ,
560+ ScalarValue :: Int32 ( Some ( 2230 ) ) ,
561+ ] ;
562+
395563 let ( result_type, result_col) =
396564 invoke_udf_test ( wkb, scalar_args, arg_types. to_vec ( ) ) . unwrap ( ) ;
397565 assert_value_equal ( & result_col, & expected) ;
0 commit comments