
Commit 2124287

conf_bug_fix
1 parent a197fd4 commit 2124287

2 files changed: +76 -81 lines changed

native/spark-expr/src/agg_funcs/sum_int.rs

Lines changed: 51 additions & 64 deletions
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.

-use crate::EvalMode;
+use crate::{arithmetic_overflow_error, EvalMode};
 use arrow::array::{
     cast::AsArray, Array, ArrayBuilder, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType,
     BooleanArray, Int64Array, PrimitiveArray,
@@ -39,8 +39,8 @@ pub struct SumInteger {

 impl SumInteger {
     pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
-        // The `data_type` is the SUM result type passed from Spark side
-        println!("data type: {:?}", data_type);
+        // The `data_type` is the SUM result type passed from Spark side which should i64
+        println!("data type: {:?} eval_mode {:?}", data_type, eval_mode);
         match data_type {
             DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self {
                 signature: Signature::user_defined(Immutable),
@@ -75,14 +75,14 @@ impl AggregateUDFImpl for SumInteger {
     }

     fn accumulator(&self, acc_args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
-        Ok(Box::new(SumIntegerAccumulator::new()))
+        Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode)))
     }

     fn create_groups_accumulator(
         &self,
         _args: AccumulatorArgs,
     ) -> DFResult<Box<dyn GroupsAccumulator>> {
-        Ok(Box::new(SumDecimalGroupsAccumulator::new(self.eval_mode)))
+        Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode)))
     }
 }

@@ -94,10 +94,10 @@ struct SumIntegerAccumulator {
 }

 impl SumIntegerAccumulator {
-    fn new() -> Self {
+    fn new(eval_mode: EvalMode) -> Self {
         Self {
             sum: 0,
-            eval_mode: EvalMode::Legacy,
+            eval_mode,
             input_data_type: DataType::Int64,
         }
     }
@@ -113,13 +113,13 @@ impl Accumulator for SumIntegerAccumulator {
         where
             T: ArrowPrimitiveType,
         {
-            println!("match internal function data type: {:?}", sum);
             let len = int_array.len();
             for i in 0..int_array.len() {
                 if !int_array.is_null(i) {
                     let v = int_array.value(i).to_i64().ok_or_else(|| {
                         DataFusionError::Internal("Failed to convert value to i64".to_string())
                     })?;
+                    println!("sum : {:?}, v : {:?}", sum, v);
                     match eval_mode {
                         EvalMode::Legacy | EvalMode::Try => {
                             sum = v.add_wrapping(sum);
@@ -128,7 +128,7 @@ impl Accumulator for SumIntegerAccumulator {
                             match v.add_checked(sum) {
                                 Ok(v) => sum = v,
                                 Err(e) => {
-                                    return Err(DataFusionError::Internal("error".to_string()))
+                                    return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
                                 }
                             };
                         }
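Editor's note: in the two hunks above, the Legacy/Try arms keep Spark's wrap-on-overflow semantics while the Ansi arm now reports an overflow error. Below is a minimal standalone sketch of that behavior; it uses the standard library's wrapping_add/checked_add rather than arrow's ArrowNativeTypeOp::add_wrapping/add_checked and a plain string error rather than arithmetic_overflow_error, and the Mode enum and sum_i64 helper are illustrative names, not Comet code.

// Sketch: Legacy mode wraps on overflow, ANSI mode surfaces an error.
#[derive(Clone, Copy)]
enum Mode {
    Legacy,
    Ansi,
}

fn sum_i64(values: &[i64], mode: Mode) -> Result<i64, String> {
    let mut sum = 0i64;
    for &v in values {
        sum = match mode {
            // Spark's legacy behavior: wrap around silently on overflow.
            Mode::Legacy => sum.wrapping_add(v),
            // ANSI behavior: report overflow instead of producing a wrapped value.
            Mode::Ansi => sum
                .checked_add(v)
                .ok_or_else(|| "integer overflow in SUM".to_string())?,
        };
    }
    Ok(sum)
}

fn main() {
    assert_eq!(sum_i64(&[1, 2, 3], Mode::Legacy), Ok(6));
    // ANSI mode rejects the overflowing input ...
    assert!(sum_i64(&[i64::MAX, 1], Mode::Ansi).is_err());
    // ... while legacy mode wraps it around to i64::MIN.
    assert_eq!(sum_i64(&[i64::MAX, 1], Mode::Legacy), Ok(i64::MIN));
}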
@@ -157,53 +157,40 @@ impl Accumulator for SumIntegerAccumulator {
             );
             Ok(())
         } else {
-            match values.data_type() {
-                DataType::Int64 => {
-                    println!("match data type: {:?}", self.input_data_type);
-                    update_sum_internal(
-                        values
-                            .as_any()
-                            .downcast_ref::<PrimitiveArray<Int64Type>>()
-                            .unwrap(),
-                        self.eval_mode,
-                        self.sum,
-                    )?;
-                }
-                DataType::Int32 => {
-                    println!("match data type: {:?}", self.input_data_type);
-                    update_sum_internal(
-                        values
-                            .as_any()
-                            .downcast_ref::<PrimitiveArray<Int32Type>>()
-                            .unwrap(),
-                        self.eval_mode,
-                        self.sum,
-                    )?;
-                }
-                DataType::Int16 => {
-                    println!("match data type: {:?}", self.input_data_type);
-                    update_sum_internal(
-                        values
-                            .as_any()
-                            .downcast_ref::<PrimitiveArray<Int16Type>>()
-                            .unwrap(),
-                        self.eval_mode,
-                        self.sum,
-                    )?;
-                }
-                DataType::Int8 => {
-                    println!("match data type: {:?}", self.input_data_type);
-                    update_sum_internal(
-                        values
-                            .as_any()
-                            .downcast_ref::<PrimitiveArray<Int8Type>>()
-                            .unwrap(),
-                        self.eval_mode,
-                        self.sum,
-                    )?;
-                }
+            self.sum = match values.data_type() {
+                DataType::Int64 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int64Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    self.sum,
+                )?,
+                DataType::Int32 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int32Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    self.sum,
+                )?,
+                DataType::Int16 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int16Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    self.sum,
+                )?,
+                DataType::Int8 => update_sum_internal(
+                    values
+                        .as_any()
+                        .downcast_ref::<PrimitiveArray<Int8Type>>()
+                        .unwrap(),
+                    self.eval_mode,
+                    self.sum,
+                )?,
                 _ => {
-                    println!("unsupported input data type: {:?}", self.input_data_type);
                     panic!("Unsupported data type")
                 }
             };
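Editor's note: the hunk above turns the per-type dispatch into a match expression whose result is assigned back to self.sum; the removed code called update_sum_internal in each arm and discarded its return value, so the running sum never advanced. A minimal sketch of the same pattern follows, with hypothetical types and names (IntValues, add_all, Acc are illustrative, not the Comet API).

// Sketch: assign the result of the fallible per-type match back to the sum.
enum IntValues {
    I32(Vec<i32>),
    I64(Vec<i64>),
}

fn add_all(current: i64, values: &[i64]) -> Result<i64, String> {
    values.iter().try_fold(current, |acc, &v| {
        acc.checked_add(v)
            .ok_or_else(|| "integer overflow".to_string())
    })
}

struct Acc {
    sum: i64,
}

impl Acc {
    fn update_batch(&mut self, values: &IntValues) -> Result<(), String> {
        // The match is an expression: every arm yields the new sum and `?`
        // propagates overflow errors, so nothing is silently discarded.
        self.sum = match values {
            IntValues::I32(v) => {
                let widened: Vec<i64> = v.iter().map(|&x| x as i64).collect();
                add_all(self.sum, &widened)?
            }
            IntValues::I64(v) => add_all(self.sum, v)?,
        };
        Ok(())
    }
}

fn main() {
    let mut acc = Acc { sum: 0 };
    acc.update_batch(&IntValues::I32(vec![1, 2, 3])).unwrap();
    acc.update_batch(&IntValues::I64(vec![10])).unwrap();
    assert_eq!(acc.sum, 16);
}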
@@ -246,19 +233,19 @@ impl Accumulator for SumIntegerAccumulator {
             }
             EvalMode::Ansi => match self.sum.add_checked(that_sum.value(0)) {
                 Ok(v) => self.sum = v,
-                Err(e) => return Err(DataFusionError::Internal("error".to_string())),
+                Err(e) => return Err(DataFusionError::from(arithmetic_overflow_error("integer"))),
             },
         }
         Ok(())
     }
 }

-struct SumDecimalGroupsAccumulator {
+struct SumIntGroupsAccumulator {
     sums: Vec<i64>,
     eval_mode: EvalMode,
 }

-impl SumDecimalGroupsAccumulator {
+impl SumIntGroupsAccumulator {
     fn new(eval_mode: EvalMode) -> Self {
         Self {
             sums: Vec::new(),
@@ -267,7 +254,7 @@ impl SumDecimalGroupsAccumulator {
     }
 }

-impl GroupsAccumulator for SumDecimalGroupsAccumulator {
+impl GroupsAccumulator for SumIntGroupsAccumulator {
     fn update_batch(
         &mut self,
         values: &[ArrayRef],
@@ -285,13 +272,13 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
         for (&group_index, &value) in iter {
             match self.eval_mode {
                 EvalMode::Legacy | EvalMode::Try => {
-                    self.sums[group_index].add_wrapping(value);
+                    self.sums[group_index] = self.sums[group_index].add_wrapping(value);
                 }
                 EvalMode::Ansi => {
                     match self.sums[group_index].add_checked(value) {
-                        Ok(v) => v,
+                        Ok(v) => self.sums[group_index] = v,
                         Err(e) => {
-                            return Err(DataFusionError::Internal("integer overflow".to_string()))
+                            return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
                         }
                     };
                }
@@ -344,11 +331,11 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
         for (&group_index, &value) in iter {
             match self.eval_mode {
                 EvalMode::Legacy | EvalMode::Try => {
-                    self.sums[group_index].add_wrapping(value);
+                    self.sums[group_index] = self.sums[group_index].add_wrapping(value);
                 }
                 EvalMode::Ansi => {
                     match self.sums[group_index].add_checked(value) {
-                        Ok(v) => v,
+                        Ok(v) => self.sums[group_index] = v,
                         Err(e) => {
                             return Err(DataFusionError::Internal("integer overflow".to_string()))
                         }
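Editor's note: the groups-accumulator hunks apply the same correction per group. add_wrapping and add_checked return a new value rather than mutating the receiver, so the result must be written back into sums[group_index], and ANSI overflow is reported as an error. A standalone sketch under those assumptions (update_groups and its signature are hypothetical, not the GroupsAccumulator API):

// Sketch: per-group sums where the wrapped/checked result is assigned back.
fn update_groups(
    sums: &mut Vec<i64>,
    rows: &[(usize, i64)], // (group_index, value) pairs
    ansi: bool,
) -> Result<(), String> {
    for &(group_index, value) in rows {
        if group_index >= sums.len() {
            sums.resize(group_index + 1, 0);
        }
        if ansi {
            sums[group_index] = sums[group_index]
                .checked_add(value)
                .ok_or_else(|| "integer overflow".to_string())?;
        } else {
            // Without the assignment this line would have no effect, which is
            // the kind of dropped result the commit fixes.
            sums[group_index] = sums[group_index].wrapping_add(value);
        }
    }
    Ok(())
}

fn main() {
    let mut sums = Vec::new();
    update_groups(&mut sums, &[(0, 5), (1, 7), (0, 3)], false).unwrap();
    assert_eq!(sums, vec![8, 7]);
    // ANSI mode rejects an overflowing update for group 1.
    assert!(update_groups(&mut sums, &[(1, i64::MAX)], true).is_err());
}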

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 25 additions & 17 deletions
@@ -3053,24 +3053,32 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }

   test("ANSI support for SUM function") {
-    val data = Seq((Int.MaxValue, 10), (1, 1))
-    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
-      withParquetTable(data, "tbl") {
-        val res = spark.sql("""
-            |SELECT
-            | SUM(_1)
-            | from tbl
-            | """.stripMargin)
-
-        res.show(10, false)
-        // checkSparkMaybeThrows(res) match {
-        //   case (Some(sparkExc), Some(cometExc)) =>
-        //     assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
-        //     assert(sparkExc.getMessage.contains("overflow"))
-        //   case _ => fail("Exception should be thrown")
-        // }
+    val batchSize = 10
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test_sum.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize)
+        withParquetTable(path.toString, "tbl") {
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+            spark.table("tbl").printSchema()
+            // val res = spark.sql(
+            //   """
+            //     |SELECT
+            //     | SUM(_1)
+            //     | from tbl
+            //     | """.stripMargin)
+            // checkSparkAnswerAndOperator(res)
+          }
+          // res.show(10, false)
+          // checkSparkMaybeThrows(res) match {
+          //   case (Some(sparkExc), Some(cometExc)) =>
+          //     assert(cometExc.getMessage.contains("error"))
+          //     assert(sparkExc.getMessage.contains("overflow"))
+          //   case _ => fail("Exception should be thrown")
+          // }
+        }
+      }
     }
-    }
   }

 }