@@ -6,12 +6,13 @@ use datafusion::arrow::array::{
66} ;
77use datafusion:: arrow:: buffer:: ScalarBuffer ;
88use datafusion:: arrow:: datatypes:: { DataType , IntervalUnit , TimeUnit } ;
9+ use datafusion:: common:: internal_err;
910use datafusion:: error:: DataFusionError ;
1011use datafusion:: logical_expr:: function:: AccumulatorArgs ;
1112use datafusion:: logical_expr:: simplify:: { ExprSimplifyResult , SimplifyInfo } ;
1213use datafusion:: logical_expr:: {
1314 AggregateUDF , AggregateUDFImpl , Expr , ScalarUDF , ScalarUDFImpl , Signature , TypeSignature ,
14- Volatility ,
15+ Volatility , TIMEZONE_WILDCARD ,
1516} ;
1617use datafusion:: physical_plan:: { Accumulator , ColumnarValue } ;
1718use datafusion:: scalar:: ScalarValue ;
@@ -457,6 +458,7 @@ struct DateAddSub {
457458
458459impl DateAddSub {
459460 pub fn new ( is_add : bool ) -> DateAddSub {
461+ let tz_wildcard: Arc < str > = Arc :: from ( TIMEZONE_WILDCARD ) ;
460462 DateAddSub {
461463 is_add,
462464 signature : Signature {
@@ -473,6 +475,22 @@ impl DateAddSub {
473475 DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
474476 DataType :: Interval ( IntervalUnit :: MonthDayNano ) ,
475477 ] ) ,
478+ // We wanted this for NOW(), which has "+00:00" time zone. Using
479+ // TIMEZONE_WILDCARD to favor DST-related questions over "UTC" == "+00:00"
480+ // questions. MySQL doesn't have a timezone as this function is applied, and we
481+ // simply invoke DF's date + interval behavior.
482+ TypeSignature :: Exact ( vec![
483+ DataType :: Timestamp ( TimeUnit :: Nanosecond , Some ( tz_wildcard. clone( ) ) ) ,
484+ DataType :: Interval ( IntervalUnit :: YearMonth ) ,
485+ ] ) ,
486+ TypeSignature :: Exact ( vec![
487+ DataType :: Timestamp ( TimeUnit :: Nanosecond , Some ( tz_wildcard. clone( ) ) ) ,
488+ DataType :: Interval ( IntervalUnit :: DayTime ) ,
489+ ] ) ,
490+ TypeSignature :: Exact ( vec![
491+ DataType :: Timestamp ( TimeUnit :: Nanosecond , Some ( tz_wildcard) ) ,
492+ DataType :: Interval ( IntervalUnit :: MonthDayNano ) ,
493+ ] ) ,
476494 ] ) ,
477495 volatility : Volatility :: Immutable ,
478496 } ,
@@ -505,18 +523,22 @@ impl ScalarUDFImpl for DateAddSub {
505523 fn signature ( & self ) -> & Signature {
506524 & self . signature
507525 }
508- fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType , DataFusionError > {
509- Ok ( DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) )
526+ fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType , DataFusionError > {
527+ if arg_types. len ( ) != 2 {
528+ return Err ( DataFusionError :: Internal ( format ! ( "DateAddSub return_type expects 2 arguments, got {:?}" , arg_types) ) ) ;
529+ }
530+ match ( & arg_types[ 0 ] , & arg_types[ 1 ] ) {
531+ ( ts@DataType :: Timestamp ( _, _) , DataType :: Interval ( _) ) => Ok ( ts. clone ( ) ) ,
532+ _ => Err ( DataFusionError :: Internal ( format ! ( "DateAddSub return_type expects Timestamp and Interval arguments, got {:?}" , arg_types) ) ) ,
533+ }
510534 }
511535 fn invoke ( & self , inputs : & [ ColumnarValue ] ) -> Result < ColumnarValue , DataFusionError > {
512536 use datafusion:: arrow:: compute:: kernels:: numeric:: add;
513537 use datafusion:: arrow:: compute:: kernels:: numeric:: sub;
514538 assert_eq ! ( inputs. len( ) , 2 ) ;
515539 // DF 42.2.0 already has date + interval or date - interval. Note that `add` and `sub` are
516540 // public (defined in arrow_arith), while timestamp-specific functions they invoke,
517- // `arithmetic_op` and then `timestamp_op::<TimestampNanosecondType>`, are not.
518- //
519- // TODO upgrade DF: Double-check that the TypeSignature is actually enforced.
541+ // Arrow's `arithmetic_op` and then `timestamp_op::<TimestampNanosecondType>`, are not.
520542 datafusion:: physical_expr_common:: datum:: apply (
521543 & inputs[ 0 ] ,
522544 & inputs[ 1 ] ,
0 commit comments