diff --git a/datafusion/spark/src/function/datetime/add_months.rs b/datafusion/spark/src/function/datetime/add_months.rs new file mode 100644 index 0000000000000..50c5524fc639e --- /dev/null +++ b/datafusion/spark/src/function/datetime/add_months.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::ops::Add; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Cast, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkAddMonths { + signature: Signature, +} + +impl Default for SparkAddMonths { + fn default() -> Self { + Self::new() + } +} + +impl SparkAddMonths { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Date32, DataType::Int32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkAddMonths { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "add_months" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()) + || args + .scalar_arguments + .iter() + .any(|arg| matches!(arg, Some(sv) if sv.is_null())); + + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [date_arg, months_arg] = take_function_args("add_months", args)?; + + Ok(ExprSimplifyResult::Simplified(date_arg.add(Expr::Cast( + Cast::new( + Box::new(months_arg), + DataType::Interval(IntervalUnit::YearMonth), + ), + )))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke should not be called on a simplified add_months() function") + } +} diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs index 849aa20895990..6b016b8552106 100644 --- a/datafusion/spark/src/function/datetime/mod.rs +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod add_months; pub mod date_add; pub mod date_sub; pub mod extract; @@ -27,6 +28,7 @@ use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(add_months::SparkAddMonths, add_months); make_udf_function!(date_add::SparkDateAdd, date_add); make_udf_function!(date_sub::SparkDateSub, date_sub); make_udf_function!(extract::SparkHour, hour); @@ -40,6 +42,11 @@ make_udf_function!(next_day::SparkNextDay, next_day); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + add_months, + "Returns the date that is months months after start. The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); export_functions!(( date_add, "Returns the date that is days days after start. The function returns NULL if at least one of the input parameters is NULL.", @@ -87,6 +94,7 @@ pub mod expr_fn { pub fn functions() -> Vec> { vec![ + add_months(), date_add(), date_sub(), hour(), diff --git a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt index cae9b21dd4766..714bc0de33f5c 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/add_months.slt @@ -15,13 +15,38 @@ # specific language governing permissions and limitations # under the License. -# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT add_months('2016-08-31', 1); -## PySpark 3.5.5 Result: {'add_months(2016-08-31, 1)': datetime.date(2016, 9, 30), 'typeof(add_months(2016-08-31, 1))': 'date', 'typeof(2016-08-31)': 'string', 'typeof(1)': 'int'} -#query -#SELECT add_months('2016-08-31'::string, 1::int); +query D +SELECT add_months('2016-07-30'::date, 1::int); +---- +2016-08-30 + +query D +SELECT add_months('2016-07-30'::date, 0::int); +---- +2016-07-30 + +query D +SELECT add_months('2016-07-30'::date, 10000::int); +---- +2849-11-30 + +query D +SELECT add_months('2016-07-30'::date, -5::int); +---- +2016-02-29 + +# Test with NULL values +query D +SELECT add_months(NULL::date, 1::int); +---- +NULL + +query D +SELECT add_months('2016-07-30'::date, NULL::int); +---- +NULL + +query D +SELECT add_months(NULL::date, NULL::int); +---- +NULL diff --git a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt index a2ac7cf2edb11..cb407a6453696 100644 --- a/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt +++ b/datafusion/sqllogictest/test_files/spark/datetime/date_add.slt @@ -41,7 +41,7 @@ SELECT date_add('2016-07-30'::date, arrow_cast(1, 'Int8')); 2016-07-31 query D -SELECT date_sub('2016-07-30'::date, 0::int); +SELECT date_add('2016-07-30'::date, 0::int); ---- 2016-07-30 @@ -51,20 +51,15 @@ SELECT date_add('2016-07-30'::date, 2147483647::int)::int; -2147466637 query I -SELECT date_sub('1969-01-01'::date, 2147483647::int)::int; +SELECT date_add('1969-01-01'::date, 2147483647::int)::int; ---- -2147483284 +2147483282 query D SELECT date_add('2016-07-30'::date, 100000::int); ---- 2290-05-15 -query D -SELECT date_sub('2016-07-30'::date, 100000::int); ----- -1742-10-15 - # Test with negative day values (should subtract days) query D SELECT date_add('2016-07-30'::date, -5::int);