
Commit 1160914

feat: Support IntegralDivide function (#1428)
## Which issue does this PR close?

Closes #1422.

## Rationale for this change

Support the IntegralDivide function.

## What changes are included in this PR?

Since the DataFusion divide operator already follows integral-division semantics, we only need to convert `IntegralDivide(...)` to `Cast(Divide(...), LongType)` and then convert that to native (a small illustrative sketch follows this description).

## How are these changes tested?

Added a unit test.
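The rewrite is safe for integral inputs because JVM integer division already truncates toward zero, which is exactly what Spark's `div` requires. A minimal sketch of those semantics, assuming non-ANSI behavior where division by zero yields NULL (illustrative names, not Comet code):

```scala
// Illustration only: Spark's `a div b` on integral types truncates toward zero,
// matching plain JVM integer division, so IntegralDivide(a, b) can be expressed
// as Cast(Divide(a, b), LongType).
object IntegralDivideSemantics {
  def integralDivide(a: Long, b: Long): Option[Long] =
    if (b == 0) None // non-ANSI Spark returns NULL on division by zero
    else Some(a / b) // JVM division already truncates toward zero

  def main(args: Array[String]): Unit = {
    assert(integralDivide(7L, 2L).contains(3L))
    assert(integralDivide(-7L, 2L).contains(-3L)) // truncation, not floor
    assert(integralDivide(1L, 0L).isEmpty)
  }
}
```

For decimal inputs the rounding behavior differs, which is why the native side adds a separate `decimal_integral_div` function (see the changes in `native/spark-expr/src/math_funcs/div.rs`).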
1 parent b149983 commit 1160914

10 files changed (+201 lines added, -14 lines removed)


native/core/src/execution/planner.rs

Lines changed: 44 additions & 2 deletions
@@ -136,6 +136,11 @@ struct JoinParameters {
     pub join_type: DFJoinType,
 }
 
+#[derive(Default)]
+struct BinaryExprOptions {
+    pub is_integral_div: bool,
+}
+
 pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
 
 /// The query planner for converting Spark query plans to DataFusion query plans.
@@ -211,6 +216,16 @@ impl PhysicalPlanner {
                 DataFusionOperator::Divide,
                 input_schema,
             ),
+            ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options(
+                expr.left.as_ref().unwrap(),
+                expr.right.as_ref().unwrap(),
+                expr.return_type.as_ref(),
+                DataFusionOperator::Divide,
+                input_schema,
+                BinaryExprOptions {
+                    is_integral_div: true,
+                },
+            ),
             ExprStruct::Remainder(expr) => self.create_binary_expr(
                 expr.left.as_ref().unwrap(),
                 expr.right.as_ref().unwrap(),
@@ -873,6 +888,25 @@ impl PhysicalPlanner {
         return_type: Option<&spark_expression::DataType>,
         op: DataFusionOperator,
         input_schema: SchemaRef,
+    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
+        self.create_binary_expr_with_options(
+            left,
+            right,
+            return_type,
+            op,
+            input_schema,
+            BinaryExprOptions::default(),
+        )
+    }
+
+    fn create_binary_expr_with_options(
+        &self,
+        left: &Expr,
+        right: &Expr,
+        return_type: Option<&spark_expression::DataType>,
+        op: DataFusionOperator,
+        input_schema: SchemaRef,
+        options: BinaryExprOptions,
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         let left = self.create_expr(left, Arc::clone(&input_schema))?;
         let right = self.create_expr(right, Arc::clone(&input_schema))?;
@@ -922,13 +956,21 @@ impl PhysicalPlanner {
                 Ok(DataType::Decimal128(_p2, _s2)),
             ) => {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
+                let func_name = if options.is_integral_div {
+                    // Decimal256 division in Arrow may overflow, so we still need this variant of decimal_div.
+                    // Otherwise, we may be able to reuse the previous case-match instead of here,
+                    // see more: https://github.com/apache/datafusion-comet/pull/1428#discussion_r1972648463
+                    "decimal_integral_div"
+                } else {
+                    "decimal_div"
+                };
                 let fun_expr = create_comet_physical_fun(
-                    "decimal_div",
+                    func_name,
                     data_type.clone(),
                     &self.session_ctx.state(),
                 )?;
                 Ok(Arc::new(ScalarFunctionExpr::new(
-                    "decimal_div",
+                    func_name,
                     fun_expr,
                     vec![left, right],
                     data_type,

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ message Expr {
     BinaryExpr array_intersect = 62;
     ArrayJoin array_join = 63;
     BinaryExpr arrays_overlap = 64;
+    MathExpr integral_divide = 65;
   }
 }
 

native/spark-expr/benches/decimal_div.rs

Lines changed: 13 additions & 2 deletions
@@ -19,7 +19,7 @@ use arrow::compute::cast;
 use arrow_array::builder::Decimal128Builder;
 use arrow_schema::DataType;
 use criterion::{black_box, criterion_group, criterion_main, Criterion};
-use datafusion_comet_spark_expr::spark_decimal_div;
+use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
 use datafusion_expr::ColumnarValue;
 use std::sync::Arc;
 
@@ -40,14 +40,25 @@ fn criterion_benchmark(c: &mut Criterion) {
     let c2 = cast(c2.as_ref(), &c2_type).unwrap();
 
     let args = [ColumnarValue::Array(c1), ColumnarValue::Array(c2)];
-    c.bench_function("decimal_div", |b| {
+
+    let mut group = c.benchmark_group("decimal div");
+    group.bench_function("decimal_div", |b| {
         b.iter(|| {
             black_box(spark_decimal_div(
                 black_box(&args),
                 black_box(&DataType::Decimal128(10, 4)),
            ))
        })
    });
+
+    group.bench_function("decimal_integral_div", |b| {
+        b.iter(|| {
+            black_box(spark_decimal_integral_div(
+                black_box(&args),
+                black_box(&DataType::Decimal128(10, 4)),
+            ))
+        })
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 10 additions & 3 deletions
@@ -17,9 +17,9 @@
 
 use crate::hash_funcs::*;
 use crate::{
-    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
-    spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_unhex,
-    spark_unscaled_value, SparkChrFunc,
+    spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div,
+    spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
+    spark_unhex, spark_unscaled_value, SparkChrFunc,
 };
 use arrow_schema::DataType;
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
@@ -90,6 +90,13 @@ pub fn create_comet_physical_fun(
         "decimal_div" => {
             make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
         }
+        "decimal_integral_div" => {
+            make_comet_scalar_udf!(
+                "decimal_integral_div",
+                spark_decimal_integral_div,
+                data_type
+            )
+        }
         "murmur3_hash" => {
             let func = Arc::new(spark_murmur3_hash);
             make_comet_scalar_udf!("murmur3_hash", func, without data_type)

native/spark-expr/src/lib.rs

Lines changed: 3 additions & 3 deletions
@@ -67,9 +67,9 @@ pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
 pub use json_funcs::ToJson;
 pub use math_funcs::{
-    create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_make_decimal,
-    spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr,
-    NormalizeNaNAndZero,
+    create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
+    spark_hex, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow,
+    NegativeExpr, NormalizeNaNAndZero,
 };
 pub use string_funcs::*;
 

native/spark-expr/src/math_funcs/div.rs

Lines changed: 26 additions & 3 deletions
@@ -27,14 +27,29 @@ use datafusion_common::DataFusionError;
 use num::{BigInt, Signed, ToPrimitive};
 use std::sync::Arc;
 
+pub fn spark_decimal_div(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+    spark_decimal_div_internal(args, data_type, false)
+}
+
+pub fn spark_decimal_integral_div(
+    args: &[ColumnarValue],
+    data_type: &DataType,
+) -> Result<ColumnarValue, DataFusionError> {
+    spark_decimal_div_internal(args, data_type, true)
+}
+
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
 // Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
 // get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
 // both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale >
 // Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt.
-pub fn spark_decimal_div(
+fn spark_decimal_div_internal(
     args: &[ColumnarValue],
     data_type: &DataType,
+    is_integral_div: bool,
 ) -> Result<ColumnarValue, DataFusionError> {
     let left = &args[0];
     let right = &args[1];
@@ -69,7 +84,9 @@ pub fn spark_decimal_div(
             let l = BigInt::from(l) * &l_mul;
             let r = BigInt::from(r) * &r_mul;
             let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
-            let res = if div.is_negative() {
+            let res = if is_integral_div {
+                div
+            } else if div.is_negative() {
                 div - &five
             } else {
                 div + &five
@@ -83,7 +100,13 @@ pub fn spark_decimal_div(
             let l = l * l_mul;
             let r = r * r_mul;
             let div = if r == 0 { 0 } else { l / r };
-            let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
+            let res = if is_integral_div {
+                div
+            } else if div.is_negative() {
+                div - 5
+            } else {
+                div + 5
+            } / 10;
             res.to_i128().unwrap_or(i128::MAX)
         })?
     };
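The only behavioral difference between `spark_decimal_div` and `spark_decimal_integral_div` is the final rounding step shown above: the former rounds the quotient half away from zero using an extra guard digit, the latter keeps the truncated quotient. A minimal sketch of that difference, assuming the unscaled operands were already widened so the raw quotient carries one guard digit (illustrative helper, not the Comet implementation):

```scala
// Illustration only: decimal_div rounds half away from zero via a guard digit,
// decimal_integral_div simply drops the guard digit (truncation toward zero).
object DecimalDivRounding {
  // `l` is assumed to be the dividend's unscaled value, widened so that the
  // raw quotient has one extra digit of scale.
  def decimalDiv(l: BigInt, r: BigInt, isIntegralDiv: Boolean): BigInt = {
    val div = if (r == BigInt(0)) BigInt(0) else l / r // BigInt division truncates toward zero
    if (isIntegralDiv) div / 10 // integral divide: just drop the guard digit
    else (if (div.signum < 0) div - 5 else div + 5) / 10 // round half away from zero
  }

  def main(args: Array[String]): Unit = {
    // 7 / 2 = 3.5: rounds to 4, integral divide yields 3
    assert(decimalDiv(BigInt(70), BigInt(2), isIntegralDiv = false) == BigInt(4))
    assert(decimalDiv(BigInt(70), BigInt(2), isIntegralDiv = true) == BigInt(3))
    // -7 / 2 = -3.5: rounds to -4, integral divide yields -3
    assert(decimalDiv(BigInt(-70), BigInt(2), isIntegralDiv = false) == BigInt(-4))
    assert(decimalDiv(BigInt(-70), BigInt(2), isIntegralDiv = true) == BigInt(-3))
  }
}
```

Spark's `div` on decimals truncates toward zero, so dropping the guard digit without rounding matches its behavior.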

native/spark-expr/src/math_funcs/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ mod utils;
 
 pub use ceil::spark_ceil;
 pub use div::spark_decimal_div;
+pub use div::spark_decimal_integral_div;
 pub use floor::spark_floor;
 pub use hex::spark_hex;
 pub use internal::*;

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 39 additions & 0 deletions
@@ -20,6 +20,7 @@
 package org.apache.comet.serde
 
 import scala.collection.JavaConverters._
+import scala.math.min
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
@@ -631,6 +632,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
        }
        None
 
+      case div @ IntegralDivide(left, right, _)
+          if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
+        val rightExpr = nullIfWhenPrimitive(right)
+
+        val dataType = (left.dataType, right.dataType) match {
+          case (l: DecimalType, r: DecimalType) =>
+            // copy from IntegralDivide.resultDecimalType
+            val intDig = l.precision - l.scale + r.scale
+            DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
+          case _ => left.dataType
+        }
+
+        val divideExpr = createMathExpression(
+          expr,
+          left,
+          rightExpr,
+          inputs,
+          binding,
+          dataType,
+          getFailOnError(div),
+          (builder, mathExpr) => builder.setIntegralDivide(mathExpr))
+
+        if (divideExpr.isDefined) {
+          // cast result to long
+          castToProto(expr, None, LongType, divideExpr.get, CometEvalMode.LEGACY)
+        } else {
+          None
+        }
+
+      case div @ IntegralDivide(left, _, _) =>
+        if (!supportedDataType(left.dataType)) {
+          withInfo(div, s"Unsupported datatype ${left.dataType}")
+        }
+        if (decimalBeforeSpark34(left.dataType)) {
+          withInfo(div, "Decimal support requires Spark 3.4 or later")
+        }
+        None
+
       case rem @ Remainder(left, right, _)
           if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
         val rightExpr = nullIfWhenPrimitive(right)
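For decimal inputs, the serde computes the intermediate decimal result type using the same rule as Spark's `IntegralDivide.resultDecimalType` before casting the result to long. A hedged sketch of that rule, assuming `DecimalType.MAX_PRECISION` is 38 (illustrative object, not Comet code):

```scala
// Illustration of the result-type rule copied by the serde above: the quotient
// needs p1 - s1 + s2 integer digits, capped at the maximum precision, scale 0.
object IntegralDivideResultType {
  val MaxPrecision = 38 // assumed value of Spark's DecimalType.MAX_PRECISION

  def resultPrecision(p1: Int, s1: Int, s2: Int): Int = {
    val intDig = p1 - s1 + s2
    math.min(if (intDig == 0) 1 else intDig, MaxPrecision)
  }

  def main(args: Array[String]): Unit = {
    assert(resultPrecision(10, 2, 3) == 11) // Decimal(10,2) div Decimal(p2,3) -> Decimal(11,0)
    assert(resultPrecision(38, 10, 38) == 38) // 66 digits needed, capped at MAX_PRECISION
  }
}
```

Note that the divisor's precision does not enter the formula; only its scale widens the integer part of the quotient.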

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 54 additions & 1 deletion
@@ -1818,7 +1818,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
        "spark.sql.decimalOperations.allowPrecisionLoss" -> allowPrecisionLoss.toString) {
        val a = makeNum(p1, s1)
        val b = makeNum(p2, s2)
-       var ops = Seq("+", "-", "*", "/", "%")
+       val ops = Seq("+", "-", "*", "/", "%", "div")
        for (op <- ops) {
          checkSparkAnswerAndOperator(s"select a, b, a $op b from $table")
          checkSparkAnswerAndOperator(s"select $a, b, $a $op b from $table")
@@ -2648,4 +2648,57 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("test integral divide") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path1 = new Path(dir.toURI.toString, "test1.parquet")
+        val path2 = new Path(dir.toURI.toString, "test2.parquet")
+        makeParquetFileAllTypes(
+          path1,
+          dictionaryEnabled = dictionaryEnabled,
+          0,
+          0,
+          randomSize = 10000)
+        makeParquetFileAllTypes(
+          path2,
+          dictionaryEnabled = dictionaryEnabled,
+          0,
+          0,
+          randomSize = 10000)
+        withParquetTable(path1.toString, "tbl1") {
+          withParquetTable(path2.toString, "tbl2") {
+            // disable broadcast, as comet on spark 3.3 does not support broadcast exchange
+            withSQLConf(
+              SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+              SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+              checkSparkAnswerAndOperator("""
+                  |select
+                  | t1._2 div t2._2, div(t1._2, t2._2),
+                  | t1._3 div t2._3, div(t1._3, t2._3),
+                  | t1._4 div t2._4, div(t1._4, t2._4),
+                  | t1._5 div t2._5, div(t1._5, t2._5),
+                  | t1._9 div t2._9, div(t1._9, t2._9),
+                  | t1._10 div t2._10, div(t1._10, t2._10),
+                  | t1._11 div t2._11, div(t1._11, t2._11)
+                  | from tbl1 t1 join tbl2 t2 on t1._id = t2._id
+                  | order by t1._id""".stripMargin)
+
+              if (isSpark34Plus) {
+                // decimal support requires Spark 3.4 or later
+                checkSparkAnswerAndOperator("""
+                    |select
+                    | t1._12 div t2._12, div(t1._12, t2._12),
+                    | t1._15 div t2._15, div(t1._15, t2._15),
+                    | t1._16 div t2._16, div(t1._16, t2._16),
+                    | t1._17 div t2._17, div(t1._17, t2._17)
+                    | from tbl1 t1 join tbl2 t2 on t1._id = t2._id
+                    | order by t1._id""".stripMargin)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
 }
