Commit acfd03c

feat:support ansi mode rounding function (#2542)
1 parent c214049 commit acfd03c

15 files changed: 148 additions and 50 deletions

docs/source/user-guide/latest/expressions.md

Lines changed: 2 additions & 2 deletions
@@ -116,7 +116,7 @@ incompatible expressions.
 ## Math Expressions
 
 | Expression | SQL | Spark-Compatible? | Compatibility Notes |
-| -------------- | --------- | ----------------- | --------------------------------- |
+| -------------- | --------- | ----------------- |-----------------------------------|
 | Acos | `acos` | Yes | |
 | Add | `+` | Yes | |
 | Asin | `asin` | Yes | |
@@ -140,7 +140,7 @@ incompatible expressions.
 | Rand | `rand` | Yes | |
 | Randn | `randn` | Yes | |
 | Remainder | `%` | Yes | ANSI mode is not supported. |
-| Round | `round` | Yes | ANSI mode is not supported. |
+| Round | `round` | Yes | |
 | Signum | `signum` | Yes | |
 | Sin | `sin` | Yes | |
 | Sqrt | `sqrt` | Yes | |

native/core/src/execution/planner.rs

Lines changed: 3 additions & 1 deletion
@@ -2478,7 +2478,7 @@ impl PhysicalPlanner {
                 fun_name,
                 data_type.clone(),
                 &self.session_ctx.state(),
-                None,
+                Some(expr.fail_on_error),
             )?;
 
             let args = args
@@ -3346,6 +3346,7 @@ mod tests {
                         func: "make_array".to_string(),
                         args: vec![array_col, array_col_1],
                         return_type: None,
+                        fail_on_error: false,
                     })),
                 }],
             })),
@@ -3464,6 +3465,7 @@ mod tests {
                         func: "array_repeat".to_string(),
                         args: vec![array_col, array_col_1],
                         return_type: None,
+                        fail_on_error: false,
                     })),
                 }],
             })),
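As a side note, here is a minimal sketch of the defaulting contract assumed by this change: the planner used to pass None for the scalar function's fail_on_error argument and now forwards Some(expr.fail_on_error); because the callee applies unwrap_or(false), None and Some(false) are interchangeable, so only ANSI-enabled round expressions observe new behavior. The helper below is illustrative only, not part of the crate.

fn effective_fail_on_error(fail_on_error: Option<bool>) -> bool {
    // Mirrors the unwrap_or(false) applied in create_comet_physical_fun_with_eval_mode.
    fail_on_error.unwrap_or(false)
}

fn main() {
    assert!(!effective_fail_on_error(None));        // old planner behavior: flag omitted
    assert!(!effective_fail_on_error(Some(false))); // non-ANSI round: semantics unchanged
    assert!(effective_fail_on_error(Some(true)));   // ANSI round: overflow becomes an error
}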

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
@@ -308,6 +308,7 @@ message ScalarFunc {
   string func = 1;
   repeated Expr args = 2;
   DataType return_type = 3;
+  bool fail_on_error = 4;
 }
 
 message CaseWhen {
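For illustration, the sketch below mirrors the field layout of the extended message as it is used from Rust in the planner tests above. The struct is a stand-alone stand-in, not the prost-generated type from the proto crate, and the Expr/DataType fields are simplified to strings.

// Stand-in mirror of the generated ScalarFunc message; field names follow the
// planner test literals above, types are placeholders for illustration only.
#[derive(Debug, Default)]
struct ScalarFunc {
    func: String,
    args: Vec<String>,           // stands in for `repeated Expr args`
    return_type: Option<String>, // stands in for `DataType return_type`
    fail_on_error: bool,         // new field 4: carries Spark's ANSI flag
}

fn main() {
    // A round call serialized under ANSI mode would set the new flag.
    let round = ScalarFunc {
        func: "round".to_string(),
        args: vec!["child".to_string(), "scale".to_string()],
        return_type: None,
        fail_on_error: true,
    };
    println!("{round:?}");
}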

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 2 additions & 2 deletions
@@ -99,6 +99,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
     fail_on_error: Option<bool>,
     eval_mode: EvalMode,
 ) -> Result<Arc<ScalarUDF>, DataFusionError> {
+    let fail_on_error = fail_on_error.unwrap_or(false);
     match fun_name {
         "ceil" => {
             make_comet_scalar_udf!("ceil", spark_ceil, data_type)
@@ -119,7 +120,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
             make_comet_scalar_udf!("lpad", func, without data_type)
         }
         "round" => {
-            make_comet_scalar_udf!("round", spark_round, data_type)
+            make_comet_scalar_udf!("round", spark_round, data_type, fail_on_error)
         }
         "unscaled_value" => {
             let func = Arc::new(spark_unscaled_value);
@@ -177,7 +178,6 @@ pub fn create_comet_physical_fun_with_eval_mode(
         }
         "spark_modulo" => {
            let func = Arc::new(spark_modulo);
-            let fail_on_error = fail_on_error.unwrap_or(false);
            make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)
        }
        _ => registry.udf(fun_name).map_err(|e| {
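A rough sketch of the wrapping this relies on, under the assumption (not shown in this diff) that make_comet_scalar_udf! captures its extra arguments in a closure: the defaulted fail_on_error flag is bound alongside the return type so that spark_round's extra parameter stays hidden behind the uniform argument list the UDF machinery expects. round_like below is a hypothetical stand-in, not the real spark_round.

type Args = Vec<i64>; // simplified stand-in for &[ColumnarValue]

fn round_like(args: &Args, _data_type: &str, fail_on_error: bool) -> Result<i64, String> {
    // Pretend the first argument is the value being rounded; only ANSI mode errors.
    if fail_on_error && args[0] == i64::MIN {
        return Err("integer overflow".to_string());
    }
    Ok(args[0])
}

fn main() {
    let data_type = "bigint".to_string();
    let fail_on_error = true; // already defaulted via unwrap_or(false)
    // The closure captures both extras, exposing a single-argument call shape.
    let udf_body = move |args: &Args| round_like(args, &data_type, fail_on_error);
    let args: Args = vec![i64::MIN];
    assert!(udf_body(&args).is_err());
}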

native/spark-expr/src/math_funcs/round.rs

Lines changed: 64 additions & 23 deletions
@@ -15,49 +15,81 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::arithmetic_overflow_error;
 use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
 use arrow::array::{Array, ArrowNativeTypeOp};
 use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
 use arrow::datatypes::DataType;
+use arrow::error::ArrowError;
 use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
 use std::{cmp::min, sync::Arc};
 
 macro_rules! integer_round {
-    ($X:expr, $DIV:expr, $HALF:expr) => {{
+    ($X:expr, $DIV:expr, $HALF:expr, $FAIL_ON_ERROR:expr) => {{
         let rem = $X % $DIV;
         if rem <= -$HALF {
-            ($X - rem).sub_wrapping($DIV)
+            if $FAIL_ON_ERROR {
+                ($X - rem).sub_checked($DIV).map_err(|_| {
+                    ArrowError::ComputeError(arithmetic_overflow_error("integer").to_string())
+                })
+            } else {
+                Ok(($X - rem).sub_wrapping($DIV))
+            }
         } else if rem >= $HALF {
-            ($X - rem).add_wrapping($DIV)
+            if $FAIL_ON_ERROR {
+                ($X - rem).add_checked($DIV).map_err(|_| {
+                    ArrowError::ComputeError(arithmetic_overflow_error("integer").to_string())
+                })
+            } else {
+                Ok(($X - rem).add_wrapping($DIV))
+            }
         } else {
-            $X - rem
+            if $FAIL_ON_ERROR {
+                $X.sub_checked(rem).map_err(|_| {
+                    ArrowError::ComputeError(arithmetic_overflow_error("integer").to_string())
+                })
+            } else {
+                Ok($X.sub_wrapping(rem))
+            }
         }
     }};
 }
 
 macro_rules! round_integer_array {
-    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{
+    ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty, $FAIL_ON_ERROR:expr) => {{
         let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap();
         let ten: $NATIVE = 10;
         let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) {
             let half = div / 2;
-            arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half))
+            arrow::compute::kernels::arity::try_unary(array, |x| {
+                integer_round!(x, div, half, $FAIL_ON_ERROR)
+            })?
         } else {
-            arrow::compute::kernels::arity::unary(array, |_| 0)
+            arrow::compute::kernels::arity::try_unary(array, |_| Ok(0))?
         };
         Ok(ColumnarValue::Array(Arc::new(result)))
     }};
 }
 
 macro_rules! round_integer_scalar {
-    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{
+    ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty, $FAIL_ON_ERROR:expr) => {{
         let ten: $NATIVE = 10;
         if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) {
             let half = div / 2;
-            Ok(ColumnarValue::Scalar($TYPE(
-                $SCALAR.map(|x| integer_round!(x, div, half)),
-            )))
+            let scalar_opt = match $SCALAR {
+                Some(x) => match integer_round!(x, div, half, $FAIL_ON_ERROR) {
+                    Ok(v) => Some(v),
+                    Err(e) => {
+                        return Err(DataFusionError::ArrowError(
+                            Box::from(e),
+                            Some(DataFusionError::get_back_trace()),
+                        ))
+                    }
+                },
+                None => None,
+            };
+            Ok(ColumnarValue::Scalar($TYPE(scalar_opt)))
         } else {
             Ok(ColumnarValue::Scalar($TYPE(Some(0))))
         }
@@ -68,6 +100,7 @@ macro_rules! round_integer_scalar {
 pub fn spark_round(
     args: &[ColumnarValue],
     data_type: &DataType,
+    fail_on_error: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
     let value = &args[0];
     let point = &args[1];
@@ -76,10 +109,18 @@ pub fn spark_round(
     };
     match value {
         ColumnarValue::Array(array) => match array.data_type() {
-            DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64),
-            DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32),
-            DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16),
-            DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8),
+            DataType::Int64 if *point < 0 => {
+                round_integer_array!(array, point, Int64Array, i64, fail_on_error)
+            }
+            DataType::Int32 if *point < 0 => {
+                round_integer_array!(array, point, Int32Array, i32, fail_on_error)
+            }
+            DataType::Int16 if *point < 0 => {
+                round_integer_array!(array, point, Int16Array, i16, fail_on_error)
+            }
+            DataType::Int8 if *point < 0 => {
+                round_integer_array!(array, point, Int8Array, i8, fail_on_error)
+            }
             DataType::Decimal128(_, scale) if *scale >= 0 => {
                 let f = decimal_round_f(scale, point);
                 let (precision, scale) = get_precision_scale(data_type);
@@ -93,16 +134,16 @@ pub fn spark_round(
         },
         ColumnarValue::Scalar(a) => match a {
             ScalarValue::Int64(a) if *point < 0 => {
-                round_integer_scalar!(a, point, ScalarValue::Int64, i64)
+                round_integer_scalar!(a, point, ScalarValue::Int64, i64, fail_on_error)
             }
             ScalarValue::Int32(a) if *point < 0 => {
-                round_integer_scalar!(a, point, ScalarValue::Int32, i32)
+                round_integer_scalar!(a, point, ScalarValue::Int32, i32, fail_on_error)
             }
             ScalarValue::Int16(a) if *point < 0 => {
-                round_integer_scalar!(a, point, ScalarValue::Int16, i16)
+                round_integer_scalar!(a, point, ScalarValue::Int16, i16, fail_on_error)
             }
             ScalarValue::Int8(a) if *point < 0 => {
-                round_integer_scalar!(a, point, ScalarValue::Int8, i8)
+                round_integer_scalar!(a, point, ScalarValue::Int8, i8, fail_on_error)
             }
             ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => {
                 let f = decimal_round_f(scale, point);
@@ -158,7 +199,7 @@ mod test {
         ]))),
         ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
     ];
-    let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32)? else {
+    let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32, false)? else {
         unreachable!()
     };
     let floats = as_float32_array(&result)?;
@@ -176,7 +217,7 @@
         ]))),
         ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
     ];
-    let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64)? else {
+    let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64, false)? else {
         unreachable!()
     };
     let floats = as_float64_array(&result)?;
@@ -193,7 +234,7 @@
         ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
     ];
     let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) =
-        spark_round(&args, &DataType::Float32)?
+        spark_round(&args, &DataType::Float32, false)?
     else {
        unreachable!()
     };
@@ -209,7 +250,7 @@
         ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
     ];
     let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) =
-        spark_round(&args, &DataType::Float64)?
+        spark_round(&args, &DataType::Float64, false)?
     else {
         unreachable!()
     };
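To make the behavioral difference concrete, here is a self-contained sketch of the integer branch using plain i64 arithmetic instead of Arrow's ArrowNativeTypeOp (an illustration, not the crate's API): rounding i64::MIN to the nearest multiple of 10 overflows the i64 range, so legacy mode wraps while ANSI mode (fail_on_error = true) surfaces an error, mirroring the branches of integer_round! above.

// Simplified stand-in for the integer_round! macro: div is 10^(-scale),
// rem keeps the sign of x, and adjusting toward the nearest multiple can
// overflow at the extremes of the integer range.
fn round_i64(x: i64, div: i64, fail_on_error: bool) -> Result<i64, String> {
    let half = div / 2;
    let rem = x % div;
    let overflow = || "integer overflow in round under ANSI mode".to_string();
    if rem <= -half {
        if fail_on_error {
            (x - rem).checked_sub(div).ok_or_else(overflow)
        } else {
            Ok((x - rem).wrapping_sub(div))
        }
    } else if rem >= half {
        if fail_on_error {
            (x - rem).checked_add(div).ok_or_else(overflow)
        } else {
            Ok((x - rem).wrapping_add(div))
        }
    } else if fail_on_error {
        x.checked_sub(rem).ok_or_else(overflow)
    } else {
        Ok(x.wrapping_sub(rem))
    }
}

fn main() {
    // round(i64::MIN, -1): legacy mode wraps, ANSI mode reports the overflow.
    assert_eq!(round_i64(i64::MIN, 10, false), Ok(9_223_372_036_854_775_806));
    assert!(round_i64(i64::MIN, 10, true).is_err());
    // Values that do not overflow round identically in both modes.
    assert_eq!(round_i64(1_234, 100, true), Ok(1_200));
    assert_eq!(round_i64(1_234, 100, false), Ok(1_200));
}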

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 5 additions & 1 deletion
@@ -851,14 +851,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case UnscaledValue(child) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)
         val optExpr =
-          scalarFunctionExprToProtoWithReturnType("unscaled_value", LongType, childExpr)
+          scalarFunctionExprToProtoWithReturnType("unscaled_value", LongType, false, childExpr)
         optExprWithInfo(optExpr, expr, child)
 
       case MakeDecimal(child, precision, scale, true) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)
         val optExpr = scalarFunctionExprToProtoWithReturnType(
           "make_decimal",
           DecimalType(precision, scale),
+          false,
           childExpr)
         optExprWithInfo(optExpr, expr, child)
 
@@ -967,9 +968,11 @@ object QueryPlanSerde extends Logging with CometExprShim {
   def scalarFunctionExprToProtoWithReturnType(
       funcName: String,
       returnType: DataType,
+      failOnError: Boolean,
       args: Option[Expr]*): Option[Expr] = {
     val builder = ExprOuterClass.ScalarFunc.newBuilder()
     builder.setFunc(funcName)
+    builder.setFailOnError(failOnError)
     serializeDataType(returnType).flatMap { t =>
       builder.setReturnType(t)
       scalarFunctionExprToProto0(builder, args: _*)
@@ -979,6 +982,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
   def scalarFunctionExprToProto(funcName: String, args: Option[Expr]*): Option[Expr] = {
     val builder = ExprOuterClass.ScalarFunc.newBuilder()
     builder.setFunc(funcName)
+    builder.setFailOnError(false)
     scalarFunctionExprToProto0(builder, args: _*)
   }

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 6 additions & 9 deletions
@@ -279,14 +279,6 @@ object CometRemainder extends CometExpressionSerde[Remainder] with MathBase {
 
 object CometRound extends CometExpressionSerde[Round] {
 
-  override def getSupportLevel(expr: Round): SupportLevel = {
-    if (expr.ansiEnabled) {
-      Incompatible(Some("ANSI mode is not supported"))
-    } else {
-      Compatible(None)
-    }
-  }
-
   override def convert(
       r: Round,
       inputs: Seq[Attribute],
@@ -325,7 +317,12 @@
     // `scale` must be Int64 type in DataFusion
     val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding)
     val optExpr =
-      scalarFunctionExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr)
+      scalarFunctionExprToProtoWithReturnType(
+        "round",
+        r.dataType,
+        r.ansiEnabled,
+        childExpr,
+        scaleExpr)
     optExprWithInfo(optExpr, r, r.child)
   }

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 2 additions & 0 deletions
@@ -210,6 +210,7 @@ object CometArraysOverlap extends CometExpressionSerde[ArraysOverlap] {
     val arraysOverlapScalarExpr = scalarFunctionExprToProtoWithReturnType(
       "array_has_any",
       BooleanType,
+      false,
       leftArrayExprProto,
       rightArrayExprProto)
     optExprWithInfo(arraysOverlapScalarExpr, expr, expr.children: _*)
@@ -250,6 +251,7 @@ object CometArrayCompact extends CometExpressionSerde[Expression] {
     val arrayCompactScalarExpr = scalarFunctionExprToProtoWithReturnType(
       "array_remove_all",
       ArrayType(elementType = elementType),
+      false,
       arrayExprProto,
       nullLiteralProto)
     optExprWithInfo(arrayCompactScalarExpr, expr, expr.children: _*)

spark/src/main/scala/org/apache/comet/serde/bitwise.scala

Lines changed: 2 additions & 2 deletions
@@ -135,7 +135,7 @@ object CometBitwiseGet extends CometExpressionSerde[BitwiseGet] {
     val argProto = exprToProto(expr.left, inputs, binding)
     val posProto = exprToProto(expr.right, inputs, binding)
     val bitGetScalarExpr =
-      scalarFunctionExprToProtoWithReturnType("bit_get", ByteType, argProto, posProto)
+      scalarFunctionExprToProtoWithReturnType("bit_get", ByteType, false, argProto, posProto)
     optExprWithInfo(bitGetScalarExpr, expr, expr.children: _*)
   }
 }
@@ -147,7 +147,7 @@ object CometBitwiseCount extends CometExpressionSerde[BitwiseCount] {
       binding: Boolean): Option[ExprOuterClass.Expr] = {
     val childProto = exprToProto(expr.child, inputs, binding)
     val bitCountScalarExpr =
-      scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
+      scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, false, childProto)
     optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)
   }
 }

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 6 additions & 1 deletion
@@ -263,7 +263,12 @@ object CometTruncDate extends CometExpressionSerde[TruncDate] {
     val childExpr = exprToProtoInternal(expr.date, inputs, binding)
     val formatExpr = exprToProtoInternal(expr.format, inputs, binding)
     val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, childExpr, formatExpr)
+      scalarFunctionExprToProtoWithReturnType(
+        "date_trunc",
+        DateType,
+        false,
+        childExpr,
+        formatExpr)
     optExprWithInfo(optExpr, expr, expr.date, expr.format)
   }
 }
