Commit 03c0626

feat: implement ansi eval mode arithmetic (#2136)
1 parent 5845227 commit 03c0626

File tree

8 files changed: +357 -66 lines changed


dev/diffs/3.4.3.diff

Lines changed: 44 additions & 2 deletions
@@ -193,6 +193,19 @@ index 41fd4de2a09..44cd244d3b0 100644
 -- Test aggregate operator with codegen on and off.
 --CONFIG_DIM1 spark.sql.codegen.wholeStage=true
 --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+index 3a409eea348..38fed024c98 100644
+--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1
+ -- any evens
+ SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0');
+
++-- https://github.com/apache/datafusion-comet/issues/2215
++--SET spark.comet.exec.enabled=false
+ -- [SPARK-28024] Incorrect value when out of range
+ SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
+
 diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
 index fac23b4a26f..2b73732c33f 100644
 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -881,7 +894,7 @@ index b5b34922694..a72403780c4 100644
   protected val baseResourcePath = {
     // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 525d97e4998..5e04319dd97 100644
+index 525d97e4998..843f0472c23 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -894,7 +907,27 @@ index 525d97e4998..5e04319dd97 100644
      AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
        sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
      }
-@@ -4467,7 +4468,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+@@ -4429,7 +4430,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+
+   test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" +
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4450,7 +4452,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+
+   test("SPARK-39175: Query context of Cast should be serialized to executors" +
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4467,14 +4470,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
      val msg = intercept[SparkException] {
        sql(query).collect()
      }.getMessage
@@ -907,6 +940,15 @@ index 525d97e4998..5e04319dd97 100644
        }
      }
    }
+   }
+
+   test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
+-    "be serialized to executors when WSCG is off") {
++    "be serialized to executors when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
 index 48ad10992c5..51d1ee65422 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

dev/diffs/3.5.6.diff

Lines changed: 44 additions & 2 deletions
@@ -172,6 +172,19 @@ index 41fd4de2a09..44cd244d3b0 100644
 -- Test aggregate operator with codegen on and off.
 --CONFIG_DIM1 spark.sql.codegen.wholeStage=true
 --CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+index 3a409eea348..38fed024c98 100644
+--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
++++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int4.sql
+@@ -69,6 +69,8 @@ SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1
+ -- any evens
+ SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0');
+
++-- https://github.com/apache/datafusion-comet/issues/2215
++--SET spark.comet.exec.enabled=false
+ -- [SPARK-28024] Incorrect value when out of range
+ SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i;
+
 diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
 index fac23b4a26f..2b73732c33f 100644
 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/int8.sql
@@ -866,7 +879,7 @@ index c26757c9cff..d55775f09d7 100644
   protected val baseResourcePath = {
     // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
-index 793a0da6a86..e48e74091cb 100644
+index 793a0da6a86..181bfc16e4b 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
 @@ -1521,7 +1521,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
@@ -879,7 +892,27 @@ index 793a0da6a86..e48e74091cb 100644
      AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
        sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
      }
-@@ -4497,7 +4498,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+@@ -4459,7 +4460,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+
+   test("SPARK-39166: Query context of binary arithmetic should be serialized to executors" +
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4480,7 +4482,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
+   }
+
+   test("SPARK-39175: Query context of Cast should be serialized to executors" +
+-    " when WSCG is off") {
++    " when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
+@@ -4497,14 +4500,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
      val msg = intercept[SparkException] {
        sql(query).collect()
      }.getMessage
@@ -892,6 +925,15 @@ index 793a0da6a86..e48e74091cb 100644
        }
      }
    }
+   }
+
+   test("SPARK-39190,SPARK-39208,SPARK-39210: Query context of decimal overflow error should " +
+-    "be serialized to executors when WSCG is off") {
++    "be serialized to executors when WSCG is off",
++    IgnoreComet("https://github.com/apache/datafusion-comet/issues/2215")) {
+     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+       SQLConf.ANSI_ENABLED.key -> "true") {
+       withTable("t") {
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
 index fa1a64460fc..1d2e215d6a3 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

native/core/src/execution/planner.rs

Lines changed: 10 additions & 13 deletions
@@ -62,8 +62,9 @@ use datafusion::{
     prelude::SessionContext,
 };
 use datafusion_comet_spark_expr::{
-    create_comet_physical_fun, create_modulo_expr, create_negate_expr, BinaryOutputStyle,
-    BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond,
+    create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr,
+    create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode,
+    SparkHour, SparkMinute, SparkSecond,
 };

 use crate::execution::operators::ExecutionError::GeneralError;
@@ -242,8 +243,6 @@ impl PhysicalPlanner {
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
             ExprStruct::Add(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/536
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -255,8 +254,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::Subtract(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/535
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -268,8 +265,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::Multiply(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/534
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -281,8 +276,6 @@ impl PhysicalPlanner {
                 )
             }
             ExprStruct::Divide(expr) => {
-                // TODO respect ANSI eval mode
-                // https://github.com/apache/datafusion-comet/issues/533
                 let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
                 self.create_binary_expr(
                     expr.left.as_ref().unwrap(),
@@ -1010,21 +1003,25 @@ impl PhysicalPlanner {
             }
             _ => {
                 let data_type = return_type.map(to_arrow_datatype).unwrap();
-                if eval_mode == EvalMode::Try && data_type.is_integer() {
+                if [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode)
+                    && (data_type.is_integer()
+                        || (data_type.is_floating() && op == DataFusionOperator::Divide))
+                {
                     let op_str = match op {
                         DataFusionOperator::Plus => "checked_add",
                         DataFusionOperator::Minus => "checked_sub",
                         DataFusionOperator::Multiply => "checked_mul",
                         DataFusionOperator::Divide => "checked_div",
                         _ => {
-                            todo!("Operator yet to be implemented!");
+                            todo!("ANSI mode for Operator yet to be implemented!");
                         }
                     };
-                    let fun_expr = create_comet_physical_fun(
+                    let fun_expr = create_comet_physical_fun_with_eval_mode(
                         op_str,
                         data_type.clone(),
                         &self.session_ctx.state(),
                         None,
+                        eval_mode,
                     )?;
                     Ok(Arc::new(ScalarFunctionExpr::new(
                         op_str,
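The planner hunk above is the heart of the change: previously only Try mode, and only for integer types, was routed through the checked_* scalar functions, while Ansi fell through to the unchecked Arrow kernels. The new predicate routes both Try and Ansi to the checked kernels, and additionally includes floating-point division so that divide-by-zero can raise under ANSI semantics. Below is a minimal, std-only Rust sketch of that dispatch decision; the enums and the helper checked_fun_name are illustrative stand-ins for this page, not Comet's actual API.

#[derive(Clone, Copy, PartialEq)]
enum EvalMode {
    Legacy,
    Try,
    Ansi,
}

#[derive(Clone, Copy, PartialEq)]
enum Op {
    Plus,
    Minus,
    Multiply,
    Divide,
}

/// Returns the name of the checked scalar function to use, or None if the
/// plain (unchecked) binary expression path should be taken instead.
fn checked_fun_name(
    op: Op,
    eval_mode: EvalMode,
    is_integer: bool,
    is_floating: bool,
) -> Option<&'static str> {
    // Mirrors the new predicate:
    // [Try, Ansi].contains(&eval_mode) && (integer || (floating && divide))
    let needs_checked = [EvalMode::Try, EvalMode::Ansi].contains(&eval_mode)
        && (is_integer || (is_floating && op == Op::Divide));
    if !needs_checked {
        return None; // Legacy keeps the old fall-through path unchanged
    }
    Some(match op {
        Op::Plus => "checked_add",
        Op::Minus => "checked_sub",
        Op::Multiply => "checked_mul",
        Op::Divide => "checked_div",
    })
}

fn main() {
    // integer add under ANSI now routes to the checked kernel
    assert_eq!(checked_fun_name(Op::Plus, EvalMode::Ansi, true, false), Some("checked_add"));
    // float divide also routes (divide-by-zero must raise under ANSI)
    assert_eq!(checked_fun_name(Op::Divide, EvalMode::Try, false, true), Some("checked_div"));
    // float multiply does not, and nothing routes under Legacy
    assert_eq!(checked_fun_name(Op::Multiply, EvalMode::Ansi, false, true), None);
    assert_eq!(checked_fun_name(Op::Plus, EvalMode::Legacy, true, false), None);
    assert_eq!(checked_fun_name(Op::Minus, EvalMode::Try, true, false), Some("checked_sub"));
}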

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 31 additions & 5 deletions
@@ -21,7 +21,7 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
+    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
     SparkBitwiseCount, SparkBitwiseGet, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
@@ -64,6 +64,15 @@ macro_rules! make_comet_scalar_udf {
         );
         Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
     }};
+    ($name:expr, $func:ident, $data_type:ident, $eval_mode:ident) => {{
+        let scalar_func = CometScalarFunction::new(
+            $name.to_string(),
+            Signature::variadic_any(Volatility::Immutable),
+            $data_type.clone(),
+            Arc::new(move |args| $func(args, &$data_type, $eval_mode)),
+        );
+        Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
+    }};
 }

 /// Create a physical scalar function.
@@ -72,6 +81,23 @@ pub fn create_comet_physical_fun(
     data_type: DataType,
     registry: &dyn FunctionRegistry,
     fail_on_error: Option<bool>,
+) -> Result<Arc<ScalarUDF>, DataFusionError> {
+    create_comet_physical_fun_with_eval_mode(
+        fun_name,
+        data_type,
+        registry,
+        fail_on_error,
+        EvalMode::Legacy,
+    )
+}
+
+/// Create a physical scalar function with eval mode. Goal is to deprecate above function once all the operators have ANSI support
+pub fn create_comet_physical_fun_with_eval_mode(
+    fun_name: &str,
+    data_type: DataType,
+    registry: &dyn FunctionRegistry,
+    fail_on_error: Option<bool>,
+    eval_mode: EvalMode,
 ) -> Result<Arc<ScalarUDF>, DataFusionError> {
     match fun_name {
         "ceil" => {
@@ -117,16 +143,16 @@ pub fn create_comet_physical_fun(
             )
         }
         "checked_add" => {
-            make_comet_scalar_udf!("checked_add", checked_add, data_type)
+            make_comet_scalar_udf!("checked_add", checked_add, data_type, eval_mode)
         }
         "checked_sub" => {
-            make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
+            make_comet_scalar_udf!("checked_sub", checked_sub, data_type, eval_mode)
         }
         "checked_mul" => {
-            make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
+            make_comet_scalar_udf!("checked_mul", checked_mul, data_type, eval_mode)
         }
         "checked_div" => {
-            make_comet_scalar_udf!("checked_div", checked_div, data_type)
+            make_comet_scalar_udf!("checked_div", checked_div, data_type, eval_mode)
         }
         "murmur3_hash" => {
             let func = Arc::new(spark_murmur3_hash);
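The only difference between the new macro arm and the existing ones is the extra $eval_mode captured by the closure that wraps the kernel; correspondingly, create_comet_physical_fun is now a thin wrapper that delegates to create_comet_physical_fun_with_eval_mode with EvalMode::Legacy. Here is a self-contained Rust sketch of that capture-and-delegate pattern, using plain i32 values as stand-ins for the Arrow arrays and an illustrative error string: under Ansi an overflow surfaces as an error, under Try it becomes NULL (modelled as None below; Legacy would not reach the checked path in the planner at all).

#[derive(Clone, Copy)]
enum EvalMode {
    Legacy,
    Try,
    Ansi,
}

// Ok(Some(v)) = value, Ok(None) = SQL NULL, Err = raised error.
type Kernel = Box<dyn Fn(i32, i32) -> Result<Option<i32>, String>>;

fn checked_add(l: i32, r: i32, eval_mode: EvalMode) -> Result<Option<i32>, String> {
    match l.checked_add(r) {
        Some(v) => Ok(Some(v)),
        None => match eval_mode {
            // ANSI raises on overflow (illustrative message, not Spark's exact text)
            EvalMode::Ansi => Err("ARITHMETIC_OVERFLOW".to_string()),
            // Try yields NULL on overflow
            _ => Ok(None),
        },
    }
}

fn make_fun_with_eval_mode(eval_mode: EvalMode) -> Kernel {
    // the `move` closure captures the eval mode, like the macro arm's
    // Arc::new(move |args| $func(args, &$data_type, $eval_mode))
    Box::new(move |l, r| checked_add(l, r, eval_mode))
}

fn make_fun() -> Kernel {
    // the pre-existing entry point now just delegates with a Legacy default
    make_fun_with_eval_mode(EvalMode::Legacy)
}

fn main() {
    let ansi = make_fun_with_eval_mode(EvalMode::Ansi);
    assert!(ansi(i32::MAX, 1).is_err());

    let try_mode = make_fun_with_eval_mode(EvalMode::Try);
    assert_eq!(try_mode(i32::MAX, 1), Ok(None));

    let legacy = make_fun();
    assert_eq!(legacy(1, 2), Ok(Some(3)));
}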

native/spark-expr/src/lib.rs

Lines changed: 4 additions & 1 deletion
@@ -64,7 +64,10 @@ pub use conditional_funcs::*;
 pub use conversion_funcs::*;
 pub use nondetermenistic_funcs::*;

-pub use comet_scalar_funcs::{create_comet_physical_fun, register_all_comet_functions};
+pub use comet_scalar_funcs::{
+    create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
+    register_all_comet_functions,
+};
 pub use datetime_funcs::{
     spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
     TimestampTruncExpr,
