Skip to content

Commit 1193bb2

Browse files
authored
Add argument matcher or (#69)
1 parent df1e0b7 commit 1193bb2

File tree

4 files changed

+71
-14
lines changed

4 files changed

+71
-14
lines changed

rust/sedona-functions/src/sd_format.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl SedonaScalarKernel for SDFormatDefault {
8181
let matcher = ArgMatcher::new(
8282
vec![
8383
ArgMatcher::is_any(),
84-
ArgMatcher::is_optional(ArgMatcher::is_string()),
84+
ArgMatcher::optional(ArgMatcher::is_string()),
8585
],
8686
formatted_type,
8787
);

rust/sedona-functions/src/st_perimeter.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ pub fn st_perimeter_udf() -> SedonaScalarUDF {
2828
ArgMatcher::new(
2929
vec![
3030
ArgMatcher::is_geometry_or_geography(),
31-
ArgMatcher::is_optional(ArgMatcher::is_boolean()),
32-
ArgMatcher::is_optional(ArgMatcher::is_boolean()),
31+
ArgMatcher::optional(ArgMatcher::is_boolean()),
32+
ArgMatcher::optional(ArgMatcher::is_boolean()),
3333
],
3434
SedonaType::Arrow(DataType::Float64),
3535
),

rust/sedona-functions/src/st_transform.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@ pub fn st_transform_udf() -> SedonaScalarUDF {
3232
ArgMatcher::new(
3333
vec![
3434
ArgMatcher::is_geometry_or_geography(),
35-
ArgMatcher::is_string(),
36-
ArgMatcher::is_optional(ArgMatcher::is_string()),
37-
ArgMatcher::is_optional(ArgMatcher::is_boolean()),
35+
ArgMatcher::or(vec![ArgMatcher::is_string(), ArgMatcher::is_numeric()]),
36+
ArgMatcher::optional(ArgMatcher::or(vec![
37+
ArgMatcher::is_string(),
38+
ArgMatcher::is_numeric(),
39+
])),
40+
ArgMatcher::optional(ArgMatcher::is_boolean()),
3841
],
3942
SedonaType::Wkb(Edges::Planar, None),
4043
),

rust/sedona-schema/src/matchers.rs

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ impl ArgMatcher {
8585
if arg == &&SedonaType::Arrow(DataType::Null) || matcher.match_type(arg) {
8686
arg_iter.next(); // Consume the argument
8787
continue; // Move to the next matcher
88-
} else if matcher.is_optional() {
88+
} else if matcher.optional() {
8989
continue; // Skip the optional matcher
9090
} else {
9191
return false; // Non-optional matcher failed
9292
}
93-
} else if matcher.is_optional() {
93+
} else if matcher.optional() {
9494
continue; // Skip remaining optional matchers
9595
} else {
9696
return false; // Non-optional matcher failed with no arguments left
@@ -179,11 +179,18 @@ impl ArgMatcher {
179179
}
180180

181181
/// Matches any argument that is optional
182-
pub fn is_optional(
182+
pub fn optional(
183183
matcher: Arc<dyn TypeMatcher + Send + Sync>,
184184
) -> Arc<dyn TypeMatcher + Send + Sync> {
185185
Arc::new(OptionalMatcher { inner: matcher })
186186
}
187+
188+
/// Matches if any of the given matchers match
189+
pub fn or(
190+
matchers: Vec<Arc<dyn TypeMatcher + Send + Sync>>,
191+
) -> Arc<dyn TypeMatcher + Send + Sync> {
192+
Arc::new(OrMatcher { matchers })
193+
}
187194
}
188195

189196
/// A TypeMatcher is a predicate on a [SedonaType]
@@ -198,7 +205,7 @@ pub trait TypeMatcher: Debug {
198205
fn match_type(&self, arg: &SedonaType) -> bool;
199206

200207
/// If this argument is optional, return true
201-
fn is_optional(&self) -> bool {
208+
fn optional(&self) -> bool {
202209
false
203210
}
204211

@@ -244,7 +251,7 @@ impl TypeMatcher for OptionalMatcher {
244251
self.inner.match_type(arg)
245252
}
246253

247-
fn is_optional(&self) -> bool {
254+
fn optional(&self) -> bool {
248255
true
249256
}
250257

@@ -253,6 +260,21 @@ impl TypeMatcher for OptionalMatcher {
253260
}
254261
}
255262

263+
#[derive(Debug)]
264+
struct OrMatcher {
265+
matchers: Vec<Arc<dyn TypeMatcher + Send + Sync>>,
266+
}
267+
268+
impl TypeMatcher for OrMatcher {
269+
fn match_type(&self, arg: &SedonaType) -> bool {
270+
self.matchers.iter().any(|m| m.match_type(arg))
271+
}
272+
273+
fn type_if_null(&self) -> Option<SedonaType> {
274+
None
275+
}
276+
}
277+
256278
#[derive(Debug)]
257279
struct IsGeometryOrGeography {}
258280

@@ -446,8 +468,8 @@ mod tests {
446468
let matcher = ArgMatcher::new(
447469
vec![
448470
ArgMatcher::is_geometry(),
449-
ArgMatcher::is_optional(ArgMatcher::is_boolean()),
450-
ArgMatcher::is_optional(ArgMatcher::is_numeric()),
471+
ArgMatcher::optional(ArgMatcher::is_boolean()),
472+
ArgMatcher::optional(ArgMatcher::is_numeric()),
451473
],
452474
SedonaType::Arrow(DataType::Null),
453475
);
@@ -486,6 +508,38 @@ mod tests {
486508
]));
487509
}
488510

511+
#[test]
512+
fn or_matcher() {
513+
let matcher = ArgMatcher::new(
514+
vec![
515+
ArgMatcher::is_geometry(),
516+
ArgMatcher::or(vec![ArgMatcher::is_boolean(), ArgMatcher::is_numeric()]),
517+
],
518+
SedonaType::Arrow(DataType::Null),
519+
);
520+
521+
// Matches first arg
522+
assert!(matcher.matches(&[WKB_GEOMETRY, SedonaType::Arrow(DataType::Boolean),]));
523+
524+
// Matches second arg
525+
assert!(matcher.matches(&[WKB_GEOMETRY, SedonaType::Arrow(DataType::Int32)]));
526+
527+
// No match when second arg is incorrect type
528+
assert!(!matcher.matches(&[WKB_GEOMETRY, WKB_GEOMETRY]));
529+
530+
// No match when first arg is incorrect type
531+
assert!(!matcher.matches(&[
532+
SedonaType::Arrow(DataType::Boolean),
533+
SedonaType::Arrow(DataType::Boolean)
534+
]));
535+
536+
// Return type if null
537+
assert_eq!(
538+
ArgMatcher::or(vec![ArgMatcher::is_boolean(), ArgMatcher::is_numeric()]).type_if_null(),
539+
None
540+
);
541+
}
542+
489543
#[test]
490544
fn arg_matcher_matches_null() {
491545
for type_matcher in [
@@ -498,7 +552,7 @@ mod tests {
498552
ArgMatcher::is_string(),
499553
ArgMatcher::is_binary(),
500554
ArgMatcher::is_boolean(),
501-
ArgMatcher::is_optional(ArgMatcher::is_numeric()),
555+
ArgMatcher::optional(ArgMatcher::is_numeric()),
502556
] {
503557
let matcher = ArgMatcher::new(vec![type_matcher], SedonaType::Arrow(DataType::Null));
504558
assert!(matcher.matches(&[SedonaType::Arrow(DataType::Null)]));

0 commit comments

Comments
 (0)