Skip to content

Commit 602212d

Browse files
authored
Convert more binary functions to sqlfunc (#32109)
Follow-up to #32083. Derive the SQL function glue code for more binary functions. ### Checklist - [ ] This PR has adequate test coverage / QA involvement has been duly considered. ([trigger-ci for additional test/nightly runs](https://trigger-ci.dev.materialize.com/)) - [ ] This PR has an associated up-to-date [design doc](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/README.md), is a design doc ([template](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/00000000_template.md)), or is sufficiently small to not require a design. <!-- Reference the design in the description. --> - [ ] If this PR evolves [an existing `$T ⇔ Proto$T` mapping](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/command-and-response-binary-encoding.md) (possibly in a backwards-incompatible way), then it is tagged with a `T-proto` label. - [ ] If this PR will require changes to cloud orchestration or tests, there is a companion cloud PR to account for those changes that is tagged with the release-blocker label ([example](MaterializeInc/cloud#5021)). <!-- Ask in #team-cloud on Slack if you need help preparing the cloud PR. --> - [ ] If this PR includes major [user-facing behavior changes](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/guide-changes.md#what-changes-require-a-release-note), I have pinged the relevant PM to schedule a changelog post. --------- Signed-off-by: Moritz Hoffmann <mh@materialize.com>
1 parent abd4987 commit 602212d

File tree

85 files changed

+6956
-31
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+6956
-31
lines changed

src/expr/src/scalar/func.rs

Lines changed: 672 additions & 13 deletions
Large diffs are not rendered by default.

src/expr/src/scalar/func/binary.rs

Lines changed: 210 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -408,12 +408,25 @@ mod test {
408408
}
409409

410410
#[mz_ore::test]
411-
fn test_equivalence() {
411+
fn test_equivalence_nullable() {
412+
test_equivalence_inner(true);
413+
}
414+
415+
#[mz_ore::test]
416+
fn test_equivalence_non_nullable() {
417+
test_equivalence_inner(false);
418+
}
419+
420+
/// Test the equivalence of the binary functions in the `func` module with their
421+
/// derived sqlfunc implementation. The `input_nullable` parameter determines
422+
/// whether the input colum is marked nullable or not.
423+
fn test_equivalence_inner(input_nullable: bool) {
412424
#[track_caller]
413425
fn check<T: LazyBinaryFunc + std::fmt::Display>(
414426
new: T,
415427
old: BinaryFunc,
416-
column_ty: ColumnType,
428+
column_a_ty: &ColumnType,
429+
column_b_ty: &ColumnType,
417430
) {
418431
assert_eq!(
419432
new.propagates_nulls(),
@@ -429,34 +442,213 @@ mod test {
429442
assert_eq!(new.is_monotone(), old.is_monotone(), "is_monotone mismatch");
430443
assert_eq!(new.is_infix_op(), old.is_infix_op(), "is_infix_op mismatch");
431444
assert_eq!(
432-
new.output_type(column_ty.clone(), column_ty.clone()),
433-
old.output_type(column_ty.clone(), column_ty.clone()),
445+
new.output_type(column_a_ty.clone(), column_b_ty.clone()),
446+
old.output_type(column_a_ty.clone(), column_b_ty.clone()),
434447
"output_type mismatch"
435448
);
436449
assert_eq!(format!("{}", new), format!("{}", old), "format mismatch");
437450
}
438451
let i32_ty = ColumnType {
439-
nullable: true,
452+
nullable: input_nullable,
440453
scalar_type: ScalarType::Int32,
441454
};
442455
let ts_tz_ty = ColumnType {
443-
nullable: true,
456+
nullable: input_nullable,
444457
scalar_type: ScalarType::TimestampTz { precision: None },
445458
};
459+
let time_ty = ColumnType {
460+
nullable: input_nullable,
461+
scalar_type: ScalarType::Time,
462+
};
463+
let interval_ty = ColumnType {
464+
nullable: input_nullable,
465+
scalar_type: ScalarType::Interval,
466+
};
446467

447468
use BinaryFunc as BF;
448469

449-
check(func::AddInt16, BF::AddInt16, i32_ty.clone());
450-
check(func::AddInt32, BF::AddInt32, i32_ty.clone());
451-
check(func::AddInt64, BF::AddInt64, i32_ty.clone());
452-
check(func::AddUint16, BF::AddUInt16, i32_ty.clone());
453-
check(func::AddUint32, BF::AddUInt32, i32_ty.clone());
454-
check(func::AddUint64, BF::AddUInt64, i32_ty.clone());
455-
check(func::AddFloat32, BF::AddFloat32, i32_ty.clone());
456-
check(func::AddFloat64, BF::AddFloat64, i32_ty.clone());
457-
check(func::AddDateTime, BF::AddDateTime, i32_ty.clone());
458-
check(func::AddDateInterval, BF::AddDateInterval, i32_ty.clone());
459-
check(func::AddTimeInterval, BF::AddTimeInterval, ts_tz_ty.clone());
460-
check(func::RoundNumericBinary, BF::RoundNumeric, i32_ty.clone());
470+
// TODO: We're passing unexpected column types to the functions here,
471+
// which works because most don't look at the type. We should fix this
472+
// and pass expected column types.
473+
474+
check(func::AddInt16, BF::AddInt16, &i32_ty, &i32_ty);
475+
check(func::AddInt32, BF::AddInt32, &i32_ty, &i32_ty);
476+
check(func::AddInt64, BF::AddInt64, &i32_ty, &i32_ty);
477+
check(func::AddUint16, BF::AddUInt16, &i32_ty, &i32_ty);
478+
check(func::AddUint32, BF::AddUInt32, &i32_ty, &i32_ty);
479+
check(func::AddUint64, BF::AddUInt64, &i32_ty, &i32_ty);
480+
check(func::AddFloat32, BF::AddFloat32, &i32_ty, &i32_ty);
481+
check(func::AddFloat64, BF::AddFloat64, &i32_ty, &i32_ty);
482+
check(func::AddDateTime, BF::AddDateTime, &i32_ty, &i32_ty);
483+
check(func::AddDateInterval, BF::AddDateInterval, &i32_ty, &i32_ty);
484+
check(
485+
func::AddTimeInterval,
486+
BF::AddTimeInterval,
487+
&ts_tz_ty,
488+
&i32_ty,
489+
);
490+
check(func::RoundNumericBinary, BF::RoundNumeric, &i32_ty, &i32_ty);
491+
check(func::ConvertFrom, BF::ConvertFrom, &i32_ty, &i32_ty);
492+
check(func::Encode, BF::Encode, &i32_ty, &i32_ty);
493+
check(
494+
func::EncodedBytesCharLength,
495+
BF::EncodedBytesCharLength,
496+
&i32_ty,
497+
&i32_ty,
498+
);
499+
check(func::AddNumeric, BF::AddNumeric, &i32_ty, &i32_ty);
500+
check(func::AddInterval, BF::AddInterval, &i32_ty, &i32_ty);
501+
check(func::BitAndInt16, BF::BitAndInt16, &i32_ty, &i32_ty);
502+
check(func::BitAndInt32, BF::BitAndInt32, &i32_ty, &i32_ty);
503+
check(func::BitAndInt64, BF::BitAndInt64, &i32_ty, &i32_ty);
504+
check(func::BitAndUint16, BF::BitAndUInt16, &i32_ty, &i32_ty);
505+
check(func::BitAndUint32, BF::BitAndUInt32, &i32_ty, &i32_ty);
506+
check(func::BitAndUint64, BF::BitAndUInt64, &i32_ty, &i32_ty);
507+
check(func::BitOrInt16, BF::BitOrInt16, &i32_ty, &i32_ty);
508+
check(func::BitOrInt32, BF::BitOrInt32, &i32_ty, &i32_ty);
509+
check(func::BitOrInt64, BF::BitOrInt64, &i32_ty, &i32_ty);
510+
check(func::BitOrUint16, BF::BitOrUInt16, &i32_ty, &i32_ty);
511+
check(func::BitOrUint32, BF::BitOrUInt32, &i32_ty, &i32_ty);
512+
check(func::BitOrUint64, BF::BitOrUInt64, &i32_ty, &i32_ty);
513+
check(func::BitXorInt16, BF::BitXorInt16, &i32_ty, &i32_ty);
514+
check(func::BitXorInt32, BF::BitXorInt32, &i32_ty, &i32_ty);
515+
check(func::BitXorInt64, BF::BitXorInt64, &i32_ty, &i32_ty);
516+
check(func::BitXorUint16, BF::BitXorUInt16, &i32_ty, &i32_ty);
517+
check(func::BitXorUint32, BF::BitXorUInt32, &i32_ty, &i32_ty);
518+
check(func::BitXorUint64, BF::BitXorUInt64, &i32_ty, &i32_ty);
519+
520+
check(
521+
func::BitShiftLeftInt16,
522+
BF::BitShiftLeftInt16,
523+
&i32_ty,
524+
&i32_ty,
525+
);
526+
check(
527+
func::BitShiftLeftInt32,
528+
BF::BitShiftLeftInt32,
529+
&i32_ty,
530+
&i32_ty,
531+
);
532+
check(
533+
func::BitShiftLeftInt64,
534+
BF::BitShiftLeftInt64,
535+
&i32_ty,
536+
&i32_ty,
537+
);
538+
check(
539+
func::BitShiftLeftUint16,
540+
BF::BitShiftLeftUInt16,
541+
&i32_ty,
542+
&i32_ty,
543+
);
544+
check(
545+
func::BitShiftLeftUint32,
546+
BF::BitShiftLeftUInt32,
547+
&i32_ty,
548+
&i32_ty,
549+
);
550+
check(
551+
func::BitShiftLeftUint64,
552+
BF::BitShiftLeftUInt64,
553+
&i32_ty,
554+
&i32_ty,
555+
);
556+
557+
check(
558+
func::BitShiftRightInt16,
559+
BF::BitShiftRightInt16,
560+
&i32_ty,
561+
&i32_ty,
562+
);
563+
check(
564+
func::BitShiftRightInt32,
565+
BF::BitShiftRightInt32,
566+
&i32_ty,
567+
&i32_ty,
568+
);
569+
check(
570+
func::BitShiftRightInt64,
571+
BF::BitShiftRightInt64,
572+
&i32_ty,
573+
&i32_ty,
574+
);
575+
check(
576+
func::BitShiftRightUint16,
577+
BF::BitShiftRightUInt16,
578+
&i32_ty,
579+
&i32_ty,
580+
);
581+
check(
582+
func::BitShiftRightUint32,
583+
BF::BitShiftRightUInt32,
584+
&i32_ty,
585+
&i32_ty,
586+
);
587+
check(
588+
func::BitShiftRightUint64,
589+
BF::BitShiftRightUInt64,
590+
&i32_ty,
591+
&i32_ty,
592+
);
593+
594+
check(func::SubInt16, BF::SubInt16, &i32_ty, &i32_ty);
595+
check(func::SubInt32, BF::SubInt32, &i32_ty, &i32_ty);
596+
check(func::SubInt64, BF::SubInt64, &i32_ty, &i32_ty);
597+
check(func::SubUint16, BF::SubUInt16, &i32_ty, &i32_ty);
598+
check(func::SubUint32, BF::SubUInt32, &i32_ty, &i32_ty);
599+
check(func::SubUint64, BF::SubUInt64, &i32_ty, &i32_ty);
600+
check(func::SubFloat32, BF::SubFloat32, &i32_ty, &i32_ty);
601+
check(func::SubFloat64, BF::SubFloat64, &i32_ty, &i32_ty);
602+
check(func::SubNumeric, BF::SubNumeric, &i32_ty, &i32_ty);
603+
604+
check(func::AgeTimestamp, BF::AgeTimestamp, &i32_ty, &i32_ty);
605+
check(func::AgeTimestamptz, BF::AgeTimestampTz, &i32_ty, &i32_ty);
606+
607+
check(func::SubTimestamp, BF::SubTimestamp, &ts_tz_ty, &i32_ty);
608+
check(func::SubTimestamptz, BF::SubTimestampTz, &ts_tz_ty, &i32_ty);
609+
check(func::SubDate, BF::SubDate, &i32_ty, &i32_ty);
610+
check(func::SubTime, BF::SubTime, &i32_ty, &i32_ty);
611+
check(func::SubInterval, BF::SubInterval, &i32_ty, &i32_ty);
612+
check(func::SubDateInterval, BF::SubDateInterval, &i32_ty, &i32_ty);
613+
check(
614+
func::SubTimeInterval,
615+
BF::SubTimeInterval,
616+
&time_ty,
617+
&interval_ty,
618+
);
619+
620+
check(func::MulInt16, BF::MulInt16, &i32_ty, &i32_ty);
621+
check(func::MulInt32, BF::MulInt32, &i32_ty, &i32_ty);
622+
check(func::MulInt64, BF::MulInt64, &i32_ty, &i32_ty);
623+
check(func::MulUint16, BF::MulUInt16, &i32_ty, &i32_ty);
624+
check(func::MulUint32, BF::MulUInt32, &i32_ty, &i32_ty);
625+
check(func::MulUint64, BF::MulUInt64, &i32_ty, &i32_ty);
626+
check(func::MulFloat32, BF::MulFloat32, &i32_ty, &i32_ty);
627+
check(func::MulFloat64, BF::MulFloat64, &i32_ty, &i32_ty);
628+
check(func::MulNumeric, BF::MulNumeric, &i32_ty, &i32_ty);
629+
check(func::MulInterval, BF::MulInterval, &i32_ty, &i32_ty);
630+
631+
check(func::DivInt16, BF::DivInt16, &i32_ty, &i32_ty);
632+
check(func::DivInt32, BF::DivInt32, &i32_ty, &i32_ty);
633+
check(func::DivInt64, BF::DivInt64, &i32_ty, &i32_ty);
634+
check(func::DivUint16, BF::DivUInt16, &i32_ty, &i32_ty);
635+
check(func::DivUint32, BF::DivUInt32, &i32_ty, &i32_ty);
636+
check(func::DivUint64, BF::DivUInt64, &i32_ty, &i32_ty);
637+
check(func::DivFloat32, BF::DivFloat32, &i32_ty, &i32_ty);
638+
check(func::DivFloat64, BF::DivFloat64, &i32_ty, &i32_ty);
639+
check(func::DivNumeric, BF::DivNumeric, &i32_ty, &i32_ty);
640+
check(func::DivInterval, BF::DivInterval, &i32_ty, &i32_ty);
641+
642+
check(func::ModInt16, BF::ModInt16, &i32_ty, &i32_ty);
643+
check(func::ModInt32, BF::ModInt32, &i32_ty, &i32_ty);
644+
check(func::ModInt64, BF::ModInt64, &i32_ty, &i32_ty);
645+
check(func::ModUint16, BF::ModUInt16, &i32_ty, &i32_ty);
646+
check(func::ModUint32, BF::ModUInt32, &i32_ty, &i32_ty);
647+
check(func::ModUint64, BF::ModUInt64, &i32_ty, &i32_ty);
648+
check(func::ModFloat32, BF::ModFloat32, &i32_ty, &i32_ty);
649+
check(func::ModFloat64, BF::ModFloat64, &i32_ty, &i32_ty);
650+
check(func::ModNumeric, BF::ModNumeric, &i32_ty, &i32_ty);
651+
652+
check(func::ArrayLength, BF::ArrayLength, &i32_ty, &i32_ty);
461653
}
462654
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
---
2+
source: src/expr/src/scalar/func.rs
3+
expression: "#[sqlfunc(\n is_monotone = (true, true),\n output_type = Interval,\n is_infix_op = true,\n sqlname = \"+\",\n propagates_nulls = true\n)]\nfn add_interval<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {\n a.unwrap_interval()\n .checked_add(&b.unwrap_interval())\n .ok_or(EvalError::IntervalOutOfRange(format!(\"{a} + {b}\").into()))\n .map(Datum::from)\n}\n"
4+
---
5+
#[derive(
6+
proptest_derive::Arbitrary,
7+
Ord,
8+
PartialOrd,
9+
Clone,
10+
Debug,
11+
Eq,
12+
PartialEq,
13+
serde::Serialize,
14+
serde::Deserialize,
15+
Hash,
16+
mz_lowertest::MzReflect
17+
)]
18+
pub struct AddInterval;
19+
impl<'a> crate::func::binary::EagerBinaryFunc<'a> for AddInterval {
20+
type Input1 = Datum<'a>;
21+
type Input2 = Datum<'a>;
22+
type Output = Result<Datum<'a>, EvalError>;
23+
fn call(
24+
&self,
25+
a: Self::Input1,
26+
b: Self::Input2,
27+
temp_storage: &'a mz_repr::RowArena,
28+
) -> Self::Output {
29+
add_interval(a, b)
30+
}
31+
fn output_type(
32+
&self,
33+
input_type_a: mz_repr::ColumnType,
34+
input_type_b: mz_repr::ColumnType,
35+
) -> mz_repr::ColumnType {
36+
use mz_repr::AsColumnType;
37+
let output = <Interval>::as_column_type();
38+
let propagates_nulls = crate::func::binary::EagerBinaryFunc::propagates_nulls(
39+
self,
40+
);
41+
let nullable = output.nullable;
42+
output
43+
.nullable(
44+
nullable
45+
|| (propagates_nulls
46+
&& (input_type_a.nullable || input_type_b.nullable)),
47+
)
48+
}
49+
fn introduces_nulls(&self) -> bool {
50+
<Interval as ::mz_repr::DatumType<'_, ()>>::nullable()
51+
}
52+
fn is_infix_op(&self) -> bool {
53+
true
54+
}
55+
fn is_monotone(&self) -> (bool, bool) {
56+
(true, true)
57+
}
58+
fn propagates_nulls(&self) -> bool {
59+
true
60+
}
61+
}
62+
impl std::fmt::Display for AddInterval {
63+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
64+
f.write_str("+")
65+
}
66+
}
67+
fn add_interval<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
68+
a.unwrap_interval()
69+
.checked_add(&b.unwrap_interval())
70+
.ok_or(EvalError::IntervalOutOfRange(format!("{a} + {b}").into()))
71+
.map(Datum::from)
72+
}

0 commit comments

Comments
 (0)