Commit 040aa08

impl_ansi

1 parent: f68203f

File tree

4 files changed: +19, -16 lines


native/core/src/execution/planner.rs

Lines changed: 2 additions & 1 deletion
@@ -1086,7 +1086,8 @@ impl PhysicalPlanner {
             }
             _ => {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
-                if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode) && (data_type.is_numeric())
+                if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode)
+                    && (data_type.is_integer() || data_type.is_floating())
                 {
                     let op_str = match op {
                         DataFusionOperator::Plus => "checked_add",
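
A side note on this hunk: in arrow-rs, DataType::is_numeric() is true for decimal types as well as integers and floats, so the old condition also routed decimal inputs through the checked_* operators; the new check restricts that path to integers and floats. A minimal sketch of the distinction (plain arrow-rs, not part of this commit):

use arrow::datatypes::DataType;

fn main() {
    let decimal = DataType::Decimal128(19, 0);
    // is_numeric() matches decimals along with ints and floats.
    assert!(decimal.is_numeric());
    // The narrower predicate in the new condition excludes decimals,
    // which are handled by the dedicated decimal kernels instead.
    assert!(!(decimal.is_integer() || decimal.is_floating()));
}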

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 1 deletion
@@ -136,10 +136,12 @@ pub fn create_comet_physical_fun_with_eval_mode(
             make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
         }
         "decimal_integral_div" => {
+            let is_ansi_mode = eval_mode == EvalMode::Ansi;
             make_comet_scalar_udf!(
                 "decimal_integral_div",
                 spark_decimal_integral_div,
-                data_type
+                data_type,
+                is_ansi_mode
             )
         }
         "checked_add" => {

native/spark-expr/src/math_funcs/div.rs

Lines changed: 12 additions & 5 deletions
@@ -15,30 +15,33 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::divide_by_zero_error;
 use crate::math_funcs::utils::get_precision_scale;
 use arrow::array::{Array, Decimal128Array};
 use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow::error::ArrowError;
 use arrow::{
     array::{ArrayRef, AsArray},
     datatypes::Decimal128Type,
 };
 use datafusion::common::DataFusionError;
 use datafusion::physical_plan::ColumnarValue;
-use num::{BigInt, Signed, ToPrimitive};
+use num::{BigInt, Signed, ToPrimitive, Zero};
 use std::sync::Arc;
 
 pub fn spark_decimal_div(
     args: &[ColumnarValue],
     data_type: &DataType,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, false)
+    spark_decimal_div_internal(args, data_type, false, false)
 }
 
 pub fn spark_decimal_integral_div(
     args: &[ColumnarValue],
     data_type: &DataType,
+    is_ansi_enabled: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
-    spark_decimal_div_internal(args, data_type, true)
+    spark_decimal_div_internal(args, data_type, true, is_ansi_enabled)
 }
 
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
@@ -50,6 +53,7 @@ fn spark_decimal_div_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
     is_integral_div: bool,
+    is_ansi_enabled: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];
@@ -96,9 +100,12 @@
     } else {
         let l_mul = 10_i128.pow(l_exp);
         let r_mul = 10_i128.pow(r_exp);
-        arrow::compute::kernels::arity::binary(left, right, |l, r| {
+        arrow::compute::kernels::arity::try_binary(left, right, |l, r| {
             let l = l * l_mul;
             let r = r * r_mul;
+            if is_ansi_enabled && r.is_zero() {
+                return Err(ArrowError::ComputeError(divide_by_zero_error().to_string()));
+            }
             let div = if r == 0 { 0 } else { l / r };
             let res = if is_integral_div {
                 div
@@ -107,7 +114,7 @@
             } else {
                 div + 5
             } / 10;
-            res.to_i128().unwrap_or(i128::MAX)
+            Ok(res.to_i128().unwrap_or(i128::MAX))
         })?
     };
     let result = result.with_data_type(DataType::Decimal128(p3, s3));
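
Net effect of the div.rs changes: the element-wise closure switches from binary to try_binary so it can fail per row, and in ANSI mode a zero divisor now raises divide_by_zero_error instead of falling back to 0. A rough usage sketch (hypothetical test; import path assumed):

use std::sync::Arc;
use arrow::array::Decimal128Array;
use arrow::datatypes::DataType;
use datafusion::common::Result;
use datafusion::physical_plan::ColumnarValue;
use datafusion_comet_spark_expr::spark_decimal_integral_div; // assumed export path

fn demo() -> Result<()> {
    // 1.00 div 0.00 as single-element Decimal128 columns.
    let left = Decimal128Array::from(vec![100_i128]).with_precision_and_scale(10, 2)?;
    let right = Decimal128Array::from(vec![0_i128]).with_precision_and_scale(10, 2)?;
    let args = [
        ColumnarValue::Array(Arc::new(left)),
        ColumnarValue::Array(Arc::new(right)),
    ];
    let ret = DataType::Decimal128(19, 0);
    // Non-ANSI: the zero divisor is tolerated (the kernel substitutes 0).
    let _quiet = spark_decimal_integral_div(&args, &ret, false)?;
    // ANSI: the same call should surface a divide-by-zero error.
    assert!(spark_decimal_integral_div(&args, &ret, true).is_err());
    Ok(())
}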

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 2 additions & 9 deletions
@@ -180,14 +180,6 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
 
 object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
 
-  override def getSupportLevel(expr: IntegralDivide): SupportLevel = {
-    if (expr.evalMode == EvalMode.ANSI) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       expr: IntegralDivide,
       inputs: Seq[Attribute],
@@ -206,7 +198,8 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with MathBase {
       if (expr.right.dataType.isInstanceOf[DecimalType]) expr.right
       else Cast(expr.right, DecimalType(19, 0))
 
-    val rightExpr = nullIfWhenPrimitive(expr.right)
+    val rightExpr =
+      if (expr.evalMode != EvalMode.ANSI) nullIfWhenPrimitive(expr.right) else expr.right
 
     val dataType = (left.dataType, right.dataType) match {
       case (l: DecimalType, r: DecimalType) =>
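
With the getSupportLevel override removed, ANSI IntegralDivide is serialized to native code, and the nullIfWhenPrimitive rewrite, which guards a primitive divisor so that 0 becomes NULL before the native kernel ever divides, is now applied only in non-ANSI mode. A small Rust model of that rewrite's semantics (null_if_zero and integral_div are illustrative names, not Comet code):

// Non-ANSI Spark semantics: x div 0 evaluates to NULL, modeled with Option.
fn null_if_zero(divisor: i128) -> Option<i128> {
    if divisor == 0 {
        None
    } else {
        Some(divisor)
    }
}

fn integral_div(dividend: i128, divisor: Option<i128>) -> Option<i128> {
    // A NULL divisor propagates, so the division itself is never attempted.
    divisor.map(|d| dividend / d)
}

fn main() {
    assert_eq!(integral_div(1, null_if_zero(0)), None); // non-ANSI: 1 div 0 => NULL
    assert_eq!(integral_div(7, null_if_zero(2)), Some(3));
    // In ANSI mode the rewrite is skipped, and the native kernel raises a
    // divide-by-zero error instead (see div.rs above).
}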
