Skip to content

Commit 2ce060b

Browse files
committed
impl_try_mode
1 parent 2124287 commit 2ce060b

File tree

5 files changed

+42
-22
lines changed

5 files changed

+42
-22
lines changed

native/core/src/execution/planner.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,12 +1876,7 @@ impl PhysicalPlanner {
18761876
AggregateExprBuilder::new(Arc::new(func), vec![child])
18771877
}
18781878
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
1879-
// let eval_mode = let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
1880-
let eval_mode = if expr.fail_on_error {
1881-
EvalMode::Ansi
1882-
} else {
1883-
EvalMode::Legacy
1884-
};
1879+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
18851880
let func =
18861881
AggregateUDF::new_from_impl(SumInteger::try_new(datatype, eval_mode)?);
18871882
AggregateExprBuilder::new(Arc::new(func), vec![child])

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ message Count {
121121
message Sum {
122122
Expr child = 1;
123123
DataType datatype = 2;
124-
bool fail_on_error = 3;
124+
EvalMode eval_mode = 3;
125125
}
126126

127127
message Min {

native/spark-expr/src/agg_funcs/sum_int.rs

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ impl SumInteger {
4141
pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult<Self> {
4242
// The `data_type` is the SUM result type passed from Spark side which should i64
4343
println!("data type: {:?} eval_mode {:?}", data_type, eval_mode);
44+
4445
match data_type {
4546
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self {
4647
signature: Signature::user_defined(Immutable),
@@ -121,14 +122,19 @@ impl Accumulator for SumIntegerAccumulator {
121122
})?;
122123
println!("sum : {:?}, v : {:?}", sum, v);
123124
match eval_mode {
124-
EvalMode::Legacy | EvalMode::Try => {
125+
EvalMode::Legacy => {
125126
sum = v.add_wrapping(sum);
126127
}
127-
EvalMode::Ansi => {
128+
EvalMode::Ansi | EvalMode::Try => {
128129
match v.add_checked(sum) {
129130
Ok(v) => sum = v,
130131
Err(e) => {
131-
return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
132+
if (eval_mode == EvalMode::Ansi){
133+
return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
134+
}
135+
else {
136+
sum = None.unwrap();
137+
}
132138
}
133139
};
134140
}
@@ -228,12 +234,18 @@ impl Accumulator for SumIntegerAccumulator {
228234
);
229235
let that_sum = states[0].as_primitive::<Int64Type>();
230236
match self.eval_mode {
231-
EvalMode::Legacy | EvalMode::Try => {
237+
EvalMode::Legacy => {
232238
self.sum.add_wrapping(that_sum.value(0));
233239
}
234-
EvalMode::Ansi => match self.sum.add_checked(that_sum.value(0)) {
240+
EvalMode::Ansi | EvalMode::Try => match self.sum.add_checked(that_sum.value(0)) {
235241
Ok(v) => self.sum = v,
236-
Err(e) => return Err(DataFusionError::from(arithmetic_overflow_error("integer"))),
242+
Err(e) =>
243+
if (self.eval_mode == EvalMode::Ansi){
244+
return Err(DataFusionError::from(arithmetic_overflow_error("integer"))),
245+
}
246+
else{
247+
self.sum = None.unwrap();
248+
}
237249
},
238250
}
239251
Ok(())
@@ -271,14 +283,19 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
271283

272284
for (&group_index, &value) in iter {
273285
match self.eval_mode {
274-
EvalMode::Legacy | EvalMode::Try => {
286+
EvalMode::Legacy => {
275287
self.sums[group_index] = self.sums[group_index].add_wrapping(value);
276288
}
277-
EvalMode::Ansi => {
289+
EvalMode::Ansi | EvalMode::Try => {
278290
match self.sums[group_index].add_checked(value) {
279291
Ok(v) => self.sums[group_index] = v,
280292
Err(e) => {
281-
return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
293+
if (self.eval_mode == EvalMode::Ansi){
294+
return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
295+
}
296+
else{
297+
self.sums[group_index] = None.unwrap();
298+
}
282299
}
283300
};
284301
}
@@ -333,11 +350,16 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
333350
EvalMode::Legacy | EvalMode::Try => {
334351
self.sums[group_index] = self.sums[group_index].add_wrapping(value);
335352
}
336-
EvalMode::Ansi => {
353+
EvalMode::Ansi | EvalMode::Try => {
337354
match self.sums[group_index].add_checked(value) {
338355
Ok(v) => self.sums[group_index] = v,
339356
Err(e) => {
340-
return Err(DataFusionError::Internal("integer overflow".to_string()))
357+
if (self.eval_mode == EvalMode::Ansi){
358+
return Err(DataFusionError::from(arithmetic_overflow_error("integer")))
359+
}
360+
else{
361+
self.sums[group_index] = None.unwrap();
362+
}
341363
}
342364
};
343365
}

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ import org.apache.spark.sql.types.{ByteType, DecimalType, IntegerType, LongType,
2828

2929
import org.apache.comet.CometConf
3030
import org.apache.comet.CometSparkSessionExtensions.withInfo
31-
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType}
31+
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
32+
import org.apache.comet.shims.CometEvalModeUtil
3233

3334
object CometMin extends CometAggregateExpressionSerde[Min] {
3435

@@ -201,14 +202,16 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
201202
return None
202203
}
203204

205+
val evalMode = sum.evalMode
206+
204207
val childExpr = exprToProto(sum.child, inputs, binding)
205208
val dataType = serializeDataType(sum.dataType)
206209

207210
if (childExpr.isDefined && dataType.isDefined) {
208211
val builder = ExprOuterClass.Sum.newBuilder()
209212
builder.setChild(childExpr.get)
210213
builder.setDatatype(dataType.get)
211-
builder.setFailOnError(sum.evalMode == EvalMode.ANSI)
214+
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode)))
212215

213216
Some(
214217
ExprOuterClass.AggExpr

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3069,16 +3069,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
30693069
// | """.stripMargin)
30703070
// checkSparkAnswerAndOperator(res)
30713071
}
3072-
// res.show(10, false)
3072+
// res.show(10, false)
30733073
// checkSparkMaybeThrows(res) match {
30743074
// case (Some(sparkExc), Some(cometExc)) =>
30753075
// assert(cometExc.getMessage.contains("error"))
30763076
// assert(sparkExc.getMessage.contains("overflow"))
30773077
// case _ => fail("Exception should be thrown")
30783078
// }
3079-
}
30803079
}
30813080
}
3081+
}
30823082

30833083
}
30843084

0 commit comments

Comments
 (0)