
Commit 496cad9

chore: Introduce ANSI support for remainder operation (#1971)

1 parent aa46bd7

File tree

9 files changed: +561 −26 lines changed

native/core/src/execution/planner.rs

Lines changed: 27 additions & 20 deletions
@@ -17,7 +17,6 @@
 
 //! Converts Spark physical plan to DataFusion physical plan
 
-use super::expressions::EvalMode;
 use crate::execution::operators::CopyMode;
 use crate::execution::operators::FilterExec as CometFilterExec;
 use crate::{
@@ -62,8 +61,8 @@ use datafusion::{
     prelude::SessionContext,
 };
 use datafusion_comet_spark_expr::{
-    create_comet_physical_fun, create_negate_expr, BloomFilterAgg, BloomFilterMightContain,
-    SparkHour, SparkMinute, SparkSecond,
+    create_comet_physical_fun, create_modulo_expr, create_negate_expr, BloomFilterAgg,
+    BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond,
 };
 
 use crate::execution::operators::ExecutionError::GeneralError;
@@ -268,13 +267,22 @@ impl PhysicalPlanner {
                     is_integral_div: true,
                 },
             ),
-            ExprStruct::Remainder(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Modulo,
-                input_schema,
-            ),
+            ExprStruct::Remainder(expr) => {
+                let left =
+                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let right =
+                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
+
+                let result = create_modulo_expr(
+                    left,
+                    right,
+                    expr.return_type.as_ref().map(to_arrow_datatype).unwrap(),
+                    input_schema,
+                    expr.fail_on_error,
+                    &self.session_ctx.state(),
+                );
+                result.map_err(|e| GeneralError(e.to_string()))
+            }
             ExprStruct::Eq(expr) => {
                 let left =
                     self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
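Instead of mapping Remainder onto DataFusion's built-in Modulo operator, the planner now builds the expression through create_modulo_expr and threads Spark's fail_on_error flag (true under ANSI mode) into it. The sketch below illustrates the integer semantics this enables; it is a standalone illustration with a hypothetical function name, not Comet's actual kernel.

```rust
// Standalone sketch of Spark's remainder semantics under ANSI mode.
// `spark_remainder_i32` is a hypothetical stand-in for Comet's kernel.
fn spark_remainder_i32(left: i32, right: i32, fail_on_error: bool) -> Result<Option<i32>, String> {
    if right == 0 {
        return if fail_on_error {
            // ANSI mode (spark.sql.ansi.enabled=true): x % 0 is an error.
            Err("[DIVIDE_BY_ZERO] Division by zero".to_string())
        } else {
            // Legacy mode: x % 0 evaluates to NULL.
            Ok(None)
        };
    }
    // Like Spark, the result keeps the dividend's sign; wrapping_rem
    // also covers the i32::MIN % -1 edge case without panicking.
    Ok(Some(left.wrapping_rem(right)))
}

fn main() {
    assert_eq!(spark_remainder_i32(7, -3, true), Ok(Some(1)));
    assert_eq!(spark_remainder_i32(7, 0, false), Ok(None));
    assert!(spark_remainder_i32(7, 0, true).is_err());
}
```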
@@ -851,19 +859,13 @@ impl PhysicalPlanner {
             right.data_type(&input_schema),
         ) {
             (
-                DataFusionOperator::Plus
-                | DataFusionOperator::Minus
-                | DataFusionOperator::Multiply
-                | DataFusionOperator::Modulo,
+                DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply,
                 Ok(DataType::Decimal128(p1, s1)),
                 Ok(DataType::Decimal128(p2, s2)),
             ) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
                 && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
                     >= DECIMAL128_MAX_PRECISION)
-                || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION)
-                || (op == DataFusionOperator::Modulo
-                    && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
-                        > DECIMAL128_MAX_PRECISION) =>
+                || (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
             {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
                 // For some Decimal128 operations, we need wider internal digits.
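With Remainder now lowered through create_modulo_expr rather than a binary Modulo operator, the Modulo arms of this Decimal128 precision-widening match can no longer be reached, so this hunk drops them.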
@@ -903,6 +905,7 @@
                 func_name,
                 data_type.clone(),
                 &self.session_ctx.state(),
+                None,
             )?;
             Ok(Arc::new(ScalarFunctionExpr::new(
                 func_name,
@@ -2305,8 +2308,12 @@ impl PhysicalPlanner {
             }
         };
 
-        let fun_expr =
-            create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
+        let fun_expr = create_comet_physical_fun(
+            fun_name,
+            data_type.clone(),
+            &self.session_ctx.state(),
+            None,
+        )?;
 
         let args = args
             .into_iter()
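Both generic call sites above pass None for the new fail_on_error parameter; only the spark_modulo registration in comet_scalar_funcs.rs (next file) consumes the flag.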

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,7 @@
 // under the License.
 
 use crate::hash_funcs::*;
+use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
@@ -53,13 +54,23 @@ macro_rules! make_comet_scalar_udf {
         );
         Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
     }};
+    ($name:expr, $func:ident, without $data_type:ident, $fail_on_error:ident) => {{
+        let scalar_func = CometScalarFunction::new(
+            $name.to_string(),
+            Signature::variadic_any(Volatility::Immutable),
+            $data_type,
+            Arc::new(move |args| $func(args, $fail_on_error)),
+        );
+        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
+    }};
 }
 
 /// Create a physical scalar function.
 pub fn create_comet_physical_fun(
     fun_name: &str,
     data_type: DataType,
     registry: &dyn FunctionRegistry,
+    fail_on_error: Option<bool>,
 ) -> Result<Arc<ScalarUDF>, DataFusionError> {
     match fun_name {
         "ceil" => {
@@ -144,6 +155,11 @@
             let func = Arc::new(spark_array_repeat);
             make_comet_scalar_udf!("array_repeat", func, without data_type)
         }
+        "spark_modulo" => {
+            let func = Arc::new(spark_modulo);
+            let fail_on_error = fail_on_error.unwrap_or(false);
+            make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",

native/spark-expr/src/error.rs

Lines changed: 3 additions & 0 deletions
@@ -51,6 +51,9 @@ pub enum SparkError {
     #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
     ArithmeticOverflow { from_type: String },
 
+    #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    DivideByZero,
+
     #[error("ArrowError: {0}.")]
     Arrow(ArrowError),
 
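The new variant slots into the crate's existing thiserror-derived enum, and the message text mirrors Spark's own DIVIDE_BY_ZERO error class. A minimal sketch (assuming only the thiserror crate, with a reduced hypothetical enum) of how the message renders:

```rust
// Reduced sketch of SparkError's derive; the real enum has many variants.
use thiserror::Error;

#[derive(Debug, Error)]
enum SparkErrorSketch {
    #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
    DivideByZero,
}

fn main() {
    // Display output is the exact message users see when ANSI mode trips on x % 0.
    println!("{}", SparkErrorSketch::DivideByZero);
}
```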

native/spark-expr/src/lib.rs

Lines changed: 7 additions & 3 deletions
@@ -73,9 +73,9 @@ pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
 pub use json_funcs::ToJson;
 pub use math_funcs::{
-    create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
-    spark_hex, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow,
-    NegativeExpr, NormalizeNaNAndZero,
+    create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
+    spark_decimal_integral_div, spark_floor, spark_hex, spark_make_decimal, spark_round,
+    spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
 };
 pub use string_funcs::*;
 
@@ -103,3 +103,7 @@ pub(crate) fn arithmetic_overflow_error(from_type: &str) -> SparkError {
         from_type: from_type.to_string(),
     }
 }
+
+pub(crate) fn divide_by_zero_error() -> SparkError {
+    SparkError::DivideByZero
+}
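divide_by_zero_error mirrors the arithmetic_overflow_error constructor directly above it: a crate-private helper the modulo kernel can return when fail_on_error is set and the divisor is zero.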

native/spark-expr/src/math_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@ mod div;
 mod floor;
 pub(crate) mod hex;
 pub mod internal;
+pub mod modulo_expr;
 mod negative;
 mod round;
 pub(crate) mod unhex;
@@ -31,6 +32,7 @@ pub use div::spark_decimal_integral_div;
 pub use floor::spark_floor;
 pub use hex::spark_hex;
 pub use internal::*;
+pub use modulo_expr::create_modulo_expr;
 pub use negative::{create_negate_expr, NegativeExpr};
 pub use round::spark_round;
 pub use unhex::spark_unhex;
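Together these exports wire up the new module: comet_scalar_funcs.rs imports the spark_modulo kernel directly, while the planner reaches create_modulo_expr through the re-export chain in lib.rs.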
