@@ -408,12 +408,25 @@ mod test {
408408 }
409409
410410 #[ mz_ore:: test]
411- fn test_equivalence ( ) {
411+ fn test_equivalence_nullable ( ) {
412+ test_equivalence_inner ( true ) ;
413+ }
414+
415+ #[ mz_ore:: test]
416+ fn test_equivalence_non_nullable ( ) {
417+ test_equivalence_inner ( false ) ;
418+ }
419+
420+ /// Test the equivalence of the binary functions in the `func` module with their
421+ /// derived sqlfunc implementation. The `input_nullable` parameter determines
422+ /// whether the input colum is marked nullable or not.
423+ fn test_equivalence_inner ( input_nullable : bool ) {
412424 #[ track_caller]
413425 fn check < T : LazyBinaryFunc + std:: fmt:: Display > (
414426 new : T ,
415427 old : BinaryFunc ,
416- column_ty : ColumnType ,
428+ column_a_ty : & ColumnType ,
429+ column_b_ty : & ColumnType ,
417430 ) {
418431 assert_eq ! (
419432 new. propagates_nulls( ) ,
@@ -429,34 +442,213 @@ mod test {
429442 assert_eq ! ( new. is_monotone( ) , old. is_monotone( ) , "is_monotone mismatch" ) ;
430443 assert_eq ! ( new. is_infix_op( ) , old. is_infix_op( ) , "is_infix_op mismatch" ) ;
431444 assert_eq ! (
432- new. output_type( column_ty . clone( ) , column_ty . clone( ) ) ,
433- old. output_type( column_ty . clone( ) , column_ty . clone( ) ) ,
445+ new. output_type( column_a_ty . clone( ) , column_b_ty . clone( ) ) ,
446+ old. output_type( column_a_ty . clone( ) , column_b_ty . clone( ) ) ,
434447 "output_type mismatch"
435448 ) ;
436449 assert_eq ! ( format!( "{}" , new) , format!( "{}" , old) , "format mismatch" ) ;
437450 }
438451 let i32_ty = ColumnType {
439- nullable : true ,
452+ nullable : input_nullable ,
440453 scalar_type : ScalarType :: Int32 ,
441454 } ;
442455 let ts_tz_ty = ColumnType {
443- nullable : true ,
456+ nullable : input_nullable ,
444457 scalar_type : ScalarType :: TimestampTz { precision : None } ,
445458 } ;
459+ let time_ty = ColumnType {
460+ nullable : input_nullable,
461+ scalar_type : ScalarType :: Time ,
462+ } ;
463+ let interval_ty = ColumnType {
464+ nullable : input_nullable,
465+ scalar_type : ScalarType :: Interval ,
466+ } ;
446467
447468 use BinaryFunc as BF ;
448469
449- check ( func:: AddInt16 , BF :: AddInt16 , i32_ty. clone ( ) ) ;
450- check ( func:: AddInt32 , BF :: AddInt32 , i32_ty. clone ( ) ) ;
451- check ( func:: AddInt64 , BF :: AddInt64 , i32_ty. clone ( ) ) ;
452- check ( func:: AddUint16 , BF :: AddUInt16 , i32_ty. clone ( ) ) ;
453- check ( func:: AddUint32 , BF :: AddUInt32 , i32_ty. clone ( ) ) ;
454- check ( func:: AddUint64 , BF :: AddUInt64 , i32_ty. clone ( ) ) ;
455- check ( func:: AddFloat32 , BF :: AddFloat32 , i32_ty. clone ( ) ) ;
456- check ( func:: AddFloat64 , BF :: AddFloat64 , i32_ty. clone ( ) ) ;
457- check ( func:: AddDateTime , BF :: AddDateTime , i32_ty. clone ( ) ) ;
458- check ( func:: AddDateInterval , BF :: AddDateInterval , i32_ty. clone ( ) ) ;
459- check ( func:: AddTimeInterval , BF :: AddTimeInterval , ts_tz_ty. clone ( ) ) ;
460- check ( func:: RoundNumericBinary , BF :: RoundNumeric , i32_ty. clone ( ) ) ;
470+ // TODO: We're passing unexpected column types to the functions here,
471+ // which works because most don't look at the type. We should fix this
472+ // and pass expected column types.
473+
474+ check ( func:: AddInt16 , BF :: AddInt16 , & i32_ty, & i32_ty) ;
475+ check ( func:: AddInt32 , BF :: AddInt32 , & i32_ty, & i32_ty) ;
476+ check ( func:: AddInt64 , BF :: AddInt64 , & i32_ty, & i32_ty) ;
477+ check ( func:: AddUint16 , BF :: AddUInt16 , & i32_ty, & i32_ty) ;
478+ check ( func:: AddUint32 , BF :: AddUInt32 , & i32_ty, & i32_ty) ;
479+ check ( func:: AddUint64 , BF :: AddUInt64 , & i32_ty, & i32_ty) ;
480+ check ( func:: AddFloat32 , BF :: AddFloat32 , & i32_ty, & i32_ty) ;
481+ check ( func:: AddFloat64 , BF :: AddFloat64 , & i32_ty, & i32_ty) ;
482+ check ( func:: AddDateTime , BF :: AddDateTime , & i32_ty, & i32_ty) ;
483+ check ( func:: AddDateInterval , BF :: AddDateInterval , & i32_ty, & i32_ty) ;
484+ check (
485+ func:: AddTimeInterval ,
486+ BF :: AddTimeInterval ,
487+ & ts_tz_ty,
488+ & i32_ty,
489+ ) ;
490+ check ( func:: RoundNumericBinary , BF :: RoundNumeric , & i32_ty, & i32_ty) ;
491+ check ( func:: ConvertFrom , BF :: ConvertFrom , & i32_ty, & i32_ty) ;
492+ check ( func:: Encode , BF :: Encode , & i32_ty, & i32_ty) ;
493+ check (
494+ func:: EncodedBytesCharLength ,
495+ BF :: EncodedBytesCharLength ,
496+ & i32_ty,
497+ & i32_ty,
498+ ) ;
499+ check ( func:: AddNumeric , BF :: AddNumeric , & i32_ty, & i32_ty) ;
500+ check ( func:: AddInterval , BF :: AddInterval , & i32_ty, & i32_ty) ;
501+ check ( func:: BitAndInt16 , BF :: BitAndInt16 , & i32_ty, & i32_ty) ;
502+ check ( func:: BitAndInt32 , BF :: BitAndInt32 , & i32_ty, & i32_ty) ;
503+ check ( func:: BitAndInt64 , BF :: BitAndInt64 , & i32_ty, & i32_ty) ;
504+ check ( func:: BitAndUint16 , BF :: BitAndUInt16 , & i32_ty, & i32_ty) ;
505+ check ( func:: BitAndUint32 , BF :: BitAndUInt32 , & i32_ty, & i32_ty) ;
506+ check ( func:: BitAndUint64 , BF :: BitAndUInt64 , & i32_ty, & i32_ty) ;
507+ check ( func:: BitOrInt16 , BF :: BitOrInt16 , & i32_ty, & i32_ty) ;
508+ check ( func:: BitOrInt32 , BF :: BitOrInt32 , & i32_ty, & i32_ty) ;
509+ check ( func:: BitOrInt64 , BF :: BitOrInt64 , & i32_ty, & i32_ty) ;
510+ check ( func:: BitOrUint16 , BF :: BitOrUInt16 , & i32_ty, & i32_ty) ;
511+ check ( func:: BitOrUint32 , BF :: BitOrUInt32 , & i32_ty, & i32_ty) ;
512+ check ( func:: BitOrUint64 , BF :: BitOrUInt64 , & i32_ty, & i32_ty) ;
513+ check ( func:: BitXorInt16 , BF :: BitXorInt16 , & i32_ty, & i32_ty) ;
514+ check ( func:: BitXorInt32 , BF :: BitXorInt32 , & i32_ty, & i32_ty) ;
515+ check ( func:: BitXorInt64 , BF :: BitXorInt64 , & i32_ty, & i32_ty) ;
516+ check ( func:: BitXorUint16 , BF :: BitXorUInt16 , & i32_ty, & i32_ty) ;
517+ check ( func:: BitXorUint32 , BF :: BitXorUInt32 , & i32_ty, & i32_ty) ;
518+ check ( func:: BitXorUint64 , BF :: BitXorUInt64 , & i32_ty, & i32_ty) ;
519+
520+ check (
521+ func:: BitShiftLeftInt16 ,
522+ BF :: BitShiftLeftInt16 ,
523+ & i32_ty,
524+ & i32_ty,
525+ ) ;
526+ check (
527+ func:: BitShiftLeftInt32 ,
528+ BF :: BitShiftLeftInt32 ,
529+ & i32_ty,
530+ & i32_ty,
531+ ) ;
532+ check (
533+ func:: BitShiftLeftInt64 ,
534+ BF :: BitShiftLeftInt64 ,
535+ & i32_ty,
536+ & i32_ty,
537+ ) ;
538+ check (
539+ func:: BitShiftLeftUint16 ,
540+ BF :: BitShiftLeftUInt16 ,
541+ & i32_ty,
542+ & i32_ty,
543+ ) ;
544+ check (
545+ func:: BitShiftLeftUint32 ,
546+ BF :: BitShiftLeftUInt32 ,
547+ & i32_ty,
548+ & i32_ty,
549+ ) ;
550+ check (
551+ func:: BitShiftLeftUint64 ,
552+ BF :: BitShiftLeftUInt64 ,
553+ & i32_ty,
554+ & i32_ty,
555+ ) ;
556+
557+ check (
558+ func:: BitShiftRightInt16 ,
559+ BF :: BitShiftRightInt16 ,
560+ & i32_ty,
561+ & i32_ty,
562+ ) ;
563+ check (
564+ func:: BitShiftRightInt32 ,
565+ BF :: BitShiftRightInt32 ,
566+ & i32_ty,
567+ & i32_ty,
568+ ) ;
569+ check (
570+ func:: BitShiftRightInt64 ,
571+ BF :: BitShiftRightInt64 ,
572+ & i32_ty,
573+ & i32_ty,
574+ ) ;
575+ check (
576+ func:: BitShiftRightUint16 ,
577+ BF :: BitShiftRightUInt16 ,
578+ & i32_ty,
579+ & i32_ty,
580+ ) ;
581+ check (
582+ func:: BitShiftRightUint32 ,
583+ BF :: BitShiftRightUInt32 ,
584+ & i32_ty,
585+ & i32_ty,
586+ ) ;
587+ check (
588+ func:: BitShiftRightUint64 ,
589+ BF :: BitShiftRightUInt64 ,
590+ & i32_ty,
591+ & i32_ty,
592+ ) ;
593+
594+ check ( func:: SubInt16 , BF :: SubInt16 , & i32_ty, & i32_ty) ;
595+ check ( func:: SubInt32 , BF :: SubInt32 , & i32_ty, & i32_ty) ;
596+ check ( func:: SubInt64 , BF :: SubInt64 , & i32_ty, & i32_ty) ;
597+ check ( func:: SubUint16 , BF :: SubUInt16 , & i32_ty, & i32_ty) ;
598+ check ( func:: SubUint32 , BF :: SubUInt32 , & i32_ty, & i32_ty) ;
599+ check ( func:: SubUint64 , BF :: SubUInt64 , & i32_ty, & i32_ty) ;
600+ check ( func:: SubFloat32 , BF :: SubFloat32 , & i32_ty, & i32_ty) ;
601+ check ( func:: SubFloat64 , BF :: SubFloat64 , & i32_ty, & i32_ty) ;
602+ check ( func:: SubNumeric , BF :: SubNumeric , & i32_ty, & i32_ty) ;
603+
604+ check ( func:: AgeTimestamp , BF :: AgeTimestamp , & i32_ty, & i32_ty) ;
605+ check ( func:: AgeTimestamptz , BF :: AgeTimestampTz , & i32_ty, & i32_ty) ;
606+
607+ check ( func:: SubTimestamp , BF :: SubTimestamp , & ts_tz_ty, & i32_ty) ;
608+ check ( func:: SubTimestamptz , BF :: SubTimestampTz , & ts_tz_ty, & i32_ty) ;
609+ check ( func:: SubDate , BF :: SubDate , & i32_ty, & i32_ty) ;
610+ check ( func:: SubTime , BF :: SubTime , & i32_ty, & i32_ty) ;
611+ check ( func:: SubInterval , BF :: SubInterval , & i32_ty, & i32_ty) ;
612+ check ( func:: SubDateInterval , BF :: SubDateInterval , & i32_ty, & i32_ty) ;
613+ check (
614+ func:: SubTimeInterval ,
615+ BF :: SubTimeInterval ,
616+ & time_ty,
617+ & interval_ty,
618+ ) ;
619+
620+ check ( func:: MulInt16 , BF :: MulInt16 , & i32_ty, & i32_ty) ;
621+ check ( func:: MulInt32 , BF :: MulInt32 , & i32_ty, & i32_ty) ;
622+ check ( func:: MulInt64 , BF :: MulInt64 , & i32_ty, & i32_ty) ;
623+ check ( func:: MulUint16 , BF :: MulUInt16 , & i32_ty, & i32_ty) ;
624+ check ( func:: MulUint32 , BF :: MulUInt32 , & i32_ty, & i32_ty) ;
625+ check ( func:: MulUint64 , BF :: MulUInt64 , & i32_ty, & i32_ty) ;
626+ check ( func:: MulFloat32 , BF :: MulFloat32 , & i32_ty, & i32_ty) ;
627+ check ( func:: MulFloat64 , BF :: MulFloat64 , & i32_ty, & i32_ty) ;
628+ check ( func:: MulNumeric , BF :: MulNumeric , & i32_ty, & i32_ty) ;
629+ check ( func:: MulInterval , BF :: MulInterval , & i32_ty, & i32_ty) ;
630+
631+ check ( func:: DivInt16 , BF :: DivInt16 , & i32_ty, & i32_ty) ;
632+ check ( func:: DivInt32 , BF :: DivInt32 , & i32_ty, & i32_ty) ;
633+ check ( func:: DivInt64 , BF :: DivInt64 , & i32_ty, & i32_ty) ;
634+ check ( func:: DivUint16 , BF :: DivUInt16 , & i32_ty, & i32_ty) ;
635+ check ( func:: DivUint32 , BF :: DivUInt32 , & i32_ty, & i32_ty) ;
636+ check ( func:: DivUint64 , BF :: DivUInt64 , & i32_ty, & i32_ty) ;
637+ check ( func:: DivFloat32 , BF :: DivFloat32 , & i32_ty, & i32_ty) ;
638+ check ( func:: DivFloat64 , BF :: DivFloat64 , & i32_ty, & i32_ty) ;
639+ check ( func:: DivNumeric , BF :: DivNumeric , & i32_ty, & i32_ty) ;
640+ check ( func:: DivInterval , BF :: DivInterval , & i32_ty, & i32_ty) ;
641+
642+ check ( func:: ModInt16 , BF :: ModInt16 , & i32_ty, & i32_ty) ;
643+ check ( func:: ModInt32 , BF :: ModInt32 , & i32_ty, & i32_ty) ;
644+ check ( func:: ModInt64 , BF :: ModInt64 , & i32_ty, & i32_ty) ;
645+ check ( func:: ModUint16 , BF :: ModUInt16 , & i32_ty, & i32_ty) ;
646+ check ( func:: ModUint32 , BF :: ModUInt32 , & i32_ty, & i32_ty) ;
647+ check ( func:: ModUint64 , BF :: ModUInt64 , & i32_ty, & i32_ty) ;
648+ check ( func:: ModFloat32 , BF :: ModFloat32 , & i32_ty, & i32_ty) ;
649+ check ( func:: ModFloat64 , BF :: ModFloat64 , & i32_ty, & i32_ty) ;
650+ check ( func:: ModNumeric , BF :: ModNumeric , & i32_ty, & i32_ty) ;
651+
652+ check ( func:: ArrayLength , BF :: ArrayLength , & i32_ty, & i32_ty) ;
461653 }
462654}
0 commit comments