Skip to content

Commit b380e30

Browse files
authored
[#61] Accept SRIDs as args to ST_Transform (#62)
1 parent 3d01227 commit b380e30

File tree

1 file changed

+181
-13
lines changed

1 file changed

+181
-13
lines changed

c/sedona-proj/src/st_transform.rs

Lines changed: 181 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use sedona_geometry::transform::{transform, CachingCrsEngine, CrsEngine, CrsTran
2626
use sedona_geometry::wkb_factory::WKB_MIN_PROBABLE_BYTES;
2727
use sedona_schema::crs::deserialize_crs;
2828
use sedona_schema::datatypes::{Edges, SedonaType};
29+
use sedona_schema::matchers::ArgMatcher;
2930
use std::cell::OnceCell;
3031
use std::rc::Rc;
3132
use std::sync::{Arc, RwLock};
@@ -135,7 +136,9 @@ fn define_arg_indexes(arg_types: &[SedonaType], indexes: &mut TransformArgIndexe
135136
indexes.first_crs = 1;
136137

137138
for (i, arg_type) in arg_types.iter().enumerate().skip(2) {
138-
if *arg_type == SedonaType::Arrow(DataType::Utf8) {
139+
if ArgMatcher::is_numeric().match_type(arg_type)
140+
|| ArgMatcher::is_string().match_type(arg_type)
141+
{
139142
indexes.second_crs = Some(i);
140143
} else if *arg_type == SedonaType::Arrow(DataType::Boolean) {
141144
indexes.lenient = Some(i);
@@ -154,17 +157,41 @@ impl SedonaScalarKernel for STTransform {
154157
arg_types: &[SedonaType],
155158
scalar_args: &[Option<&ScalarValue>],
156159
) -> Result<Option<SedonaType>> {
160+
let matcher = ArgMatcher::new(
161+
vec![
162+
ArgMatcher::is_geometry_or_geography(),
163+
ArgMatcher::or(vec![ArgMatcher::is_numeric(), ArgMatcher::is_string()]),
164+
ArgMatcher::optional(ArgMatcher::or(vec![
165+
ArgMatcher::is_numeric(),
166+
ArgMatcher::is_string(),
167+
])),
168+
ArgMatcher::optional(ArgMatcher::is_boolean()),
169+
],
170+
SedonaType::Wkb(Edges::Planar, None),
171+
);
172+
173+
if !matcher.matches(arg_types) {
174+
return Ok(None);
175+
}
176+
157177
let mut indexes = TransformArgIndexes::new();
158178
define_arg_indexes(arg_types, &mut indexes);
159179

160-
let to_crs_opt = if let Some(second_crs_index) = indexes.second_crs {
180+
let scalar_arg_opt = if let Some(second_crs_index) = indexes.second_crs {
161181
scalar_args.get(second_crs_index).unwrap()
162182
} else {
163183
scalar_args.get(indexes.first_crs).unwrap()
164184
};
165185

166-
match to_crs_opt {
167-
Some(ScalarValue::Utf8(Some(to_crs))) => {
186+
let crs_str_opt = if let Some(scalar_crs) = scalar_arg_opt {
187+
to_crs_str(scalar_crs)
188+
} else {
189+
None
190+
};
191+
192+
// If there is no CRS argument, we cannot determine the return type.
193+
match crs_str_opt {
194+
Some(to_crs) => {
168195
let val = serde_json::Value::String(to_crs.to_string());
169196
let crs = deserialize_crs(&val)?;
170197
Ok(Some(SedonaType::Wkb(Edges::Planar, crs)))
@@ -187,16 +214,18 @@ impl SedonaScalarKernel for STTransform {
187214
let mut indexes = TransformArgIndexes::new();
188215
define_arg_indexes(arg_types, &mut indexes);
189216

190-
let first_crs = get_scalar_str(args, indexes.first_crs).ok_or_else(|| {
191-
DataFusionError::Execution("First argument must be a scalar string".into())
217+
let first_crs = get_crs_str(args, indexes.first_crs).ok_or_else(|| {
218+
DataFusionError::Execution(
219+
"First CRS argument must be a string or numeric scalar".to_string(),
220+
)
192221
})?;
193222

194223
let lenient = indexes
195224
.lenient
196225
.is_some_and(|i| get_scalar_bool(args, i).unwrap_or(false));
197226

198227
let second_crs = if let Some(second_crs_index) = indexes.second_crs {
199-
get_scalar_str(args, second_crs_index)
228+
get_crs_str(args, second_crs_index)
200229
} else {
201230
None
202231
};
@@ -270,12 +299,23 @@ fn parse_source_crs(source_type: &SedonaType) -> Result<Option<String>> {
270299
}
271300
}
272301

273-
fn get_scalar_str(args: &[ColumnarValue], index: usize) -> Option<String> {
274-
if let Some(ColumnarValue::Scalar(ScalarValue::Utf8(opt_str))) = args.get(index) {
275-
opt_str.clone()
276-
} else {
277-
None
302+
fn to_crs_str(scalar_arg: &ScalarValue) -> Option<String> {
303+
if let Ok(ScalarValue::Utf8(Some(crs))) = scalar_arg.cast_to(&DataType::Utf8) {
304+
if crs.chars().all(|c| c.is_ascii_digit()) {
305+
return Some(format!("EPSG:{crs}"));
306+
} else {
307+
return Some(crs);
308+
}
309+
}
310+
311+
None
312+
}
313+
314+
fn get_crs_str(args: &[ColumnarValue], index: usize) -> Option<String> {
315+
if let ColumnarValue::Scalar(scalar_crs) = &args[index] {
316+
return to_crs_str(scalar_crs);
278317
}
318+
None
279319
}
280320

281321
fn get_scalar_bool(args: &[ColumnarValue], index: usize) -> Option<bool> {
@@ -303,6 +343,88 @@ mod tests {
303343
const NAD83ZONE6PROJ: &str = "EPSG:2230";
304344
const WGS84: &str = "EPSG:4326";
305345

346+
#[rstest]
347+
fn invalid_arg_checks() {
348+
let udf: SedonaScalarUDF =
349+
SedonaScalarUDF::from_kernel("st_transform", st_transform_impl());
350+
351+
// No args
352+
let result = udf.return_field_from_args(ReturnFieldArgs {
353+
arg_fields: &[],
354+
scalar_arguments: &[],
355+
});
356+
assert!(
357+
result.is_err()
358+
&& result
359+
.unwrap_err()
360+
.to_string()
361+
.contains("No kernel matching arguments")
362+
);
363+
364+
// Too many args
365+
let arg_types = [
366+
WKB_GEOMETRY,
367+
SedonaType::Arrow(DataType::Utf8),
368+
SedonaType::Arrow(DataType::Utf8),
369+
SedonaType::Arrow(DataType::Boolean),
370+
SedonaType::Arrow(DataType::Int32),
371+
];
372+
let arg_fields: Vec<Arc<Field>> = arg_types
373+
.iter()
374+
.map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
375+
.collect();
376+
let result = udf.return_field_from_args(ReturnFieldArgs {
377+
arg_fields: &arg_fields,
378+
scalar_arguments: &[None, None, None, None, None],
379+
});
380+
assert!(
381+
result.is_err()
382+
&& result
383+
.unwrap_err()
384+
.to_string()
385+
.contains("No kernel matching arguments")
386+
);
387+
388+
// First arg not geometry
389+
let arg_types = [
390+
SedonaType::Arrow(DataType::Utf8),
391+
SedonaType::Arrow(DataType::Utf8),
392+
];
393+
let arg_fields: Vec<Arc<Field>> = arg_types
394+
.iter()
395+
.map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
396+
.collect();
397+
let result = udf.return_field_from_args(ReturnFieldArgs {
398+
arg_fields: &arg_fields,
399+
scalar_arguments: &[None, None],
400+
});
401+
assert!(
402+
result.is_err()
403+
&& result
404+
.unwrap_err()
405+
.to_string()
406+
.contains("No kernel matching arguments")
407+
);
408+
409+
// Second arg not string or numeric
410+
let arg_types = [WKB_GEOMETRY, SedonaType::Arrow(DataType::Boolean)];
411+
let arg_fields: Vec<Arc<Field>> = arg_types
412+
.iter()
413+
.map(|arg_type| Arc::new(arg_type.to_storage_field("", true).unwrap()))
414+
.collect();
415+
let result = udf.return_field_from_args(ReturnFieldArgs {
416+
arg_fields: &arg_fields,
417+
scalar_arguments: &[None, None],
418+
});
419+
assert!(
420+
result.is_err()
421+
&& result
422+
.unwrap_err()
423+
.to_string()
424+
.contains("No kernel matching arguments")
425+
);
426+
}
427+
306428
#[rstest]
307429
fn test_invoke_batch_with_geo_crs() {
308430
// From-CRS pulled from sedona type
@@ -329,6 +451,32 @@ mod tests {
329451
);
330452
}
331453

454+
#[rstest]
455+
fn test_invoke_with_srids() {
456+
// Use an integer SRID for the to CRS
457+
let arg_types = [
458+
SedonaType::Wkb(Edges::Planar, lnglat()),
459+
SedonaType::Arrow(DataType::UInt32),
460+
];
461+
462+
let wkb = create_array(&[None, Some("POINT (79.3871 43.6426)")], &arg_types[0]);
463+
464+
let scalar_args = vec![ScalarValue::UInt32(Some(2230))];
465+
466+
let expected = create_array_value(
467+
&[None, Some("POINT (-21508577.363421552 34067918.06097863)")],
468+
&SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ)),
469+
);
470+
471+
let (result_type, result_col) =
472+
invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
473+
assert_value_equal(&result_col, &expected);
474+
assert_eq!(
475+
result_type,
476+
SedonaType::Wkb(Edges::Planar, get_crs(NAD83ZONE6PROJ))
477+
);
478+
}
479+
332480
#[rstest]
333481
fn test_invoke_batch_with_lenient() {
334482
let arg_types = [
@@ -372,7 +520,7 @@ mod tests {
372520
}
373521

374522
#[rstest]
375-
fn test_invoke_batch_with_string_source() {
523+
fn test_invoke_batch_with_source_arg() {
376524
let arg_types = [
377525
WKB_GEOMETRY,
378526
SedonaType::Arrow(DataType::Utf8),
@@ -392,6 +540,26 @@ mod tests {
392540
&SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap())),
393541
);
394542

543+
let (result_type, result_col) =
544+
invoke_udf_test(wkb.clone(), scalar_args, arg_types.to_vec()).unwrap();
545+
assert_value_equal(&result_col, &expected);
546+
assert_eq!(
547+
result_type,
548+
SedonaType::Wkb(Edges::Planar, Some(get_crs(NAD83ZONE6PROJ).unwrap()))
549+
);
550+
551+
// Test with integer SRIDs
552+
let arg_types = [
553+
WKB_GEOMETRY,
554+
SedonaType::Arrow(DataType::Int32),
555+
SedonaType::Arrow(DataType::Int32),
556+
];
557+
558+
let scalar_args = vec![
559+
ScalarValue::Int32(Some(4326)),
560+
ScalarValue::Int32(Some(2230)),
561+
];
562+
395563
let (result_type, result_col) =
396564
invoke_udf_test(wkb, scalar_args, arg_types.to_vec()).unwrap();
397565
assert_value_equal(&result_col, &expected);

0 commit comments

Comments
 (0)