Skip to content

Commit af21ae9

Browse files
authored
feat: add getSupportLevel for aggregates (apache#2777)
1 parent 5515741 commit af21ae9

File tree

3 files changed

+64
-29
lines changed

3 files changed

+64
-29
lines changed

spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
3939
*/
4040
def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
4141

42+
/**
43+
* Determine the support level of the expression based on its attributes.
44+
*
45+
* @param expr
46+
* The Spark expression.
47+
* @return
48+
* Support level (Compatible, Incompatible, or Unsupported).
49+
*/
50+
def getSupportLevel(expr: T): SupportLevel = Compatible(None)
51+
4252
/**
4353
* Convert a Spark expression into a protocol buffer representation that can be passed into
4454
* native code.

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,35 @@ object QueryPlanSerde extends Logging with CometExprShim {
398398
s"${CometConf.getExprEnabledConfigKey(exprConfName)}=true to enable it.")
399399
return None
400400
}
401-
aggHandler.convert(aggExpr, fn, inputs, binding, conf)
401+
aggHandler.getSupportLevel(fn) match {
402+
case Unsupported(notes) =>
403+
withInfo(fn, notes.getOrElse(""))
404+
None
405+
case Incompatible(notes) =>
406+
val exprAllowIncompat = CometConf.isExprAllowIncompat(exprConfName)
407+
if (exprAllowIncompat || CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get()) {
408+
if (notes.isDefined) {
409+
logWarning(
410+
s"Comet supports $fn when ${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true " +
411+
s"but has notes: ${notes.get}")
412+
}
413+
aggHandler.convert(aggExpr, fn, inputs, binding, conf)
414+
} else {
415+
val optionalNotes = notes.map(str => s" ($str)").getOrElse("")
416+
withInfo(
417+
fn,
418+
s"$fn is not fully compatible with Spark$optionalNotes. To enable it anyway, " +
419+
s"set ${CometConf.getExprAllowIncompatConfigKey(exprConfName)}=true, or set " +
420+
s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to enable all " +
421+
s"incompatible expressions. ${CometConf.COMPAT_GUIDE}.")
422+
None
423+
}
424+
case Compatible(notes) =>
425+
if (notes.isDefined) {
426+
logWarning(s"Comet supports $fn but has notes: ${notes.get}")
427+
}
428+
aggHandler.convert(aggExpr, fn, inputs, binding, conf)
429+
}
402430
case _ =>
403431
withInfo(
404432
aggExpr,

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
149149
}
150150

151151
object CometAverage extends CometAggregateExpressionSerde[Average] {
152+
153+
override def getSupportLevel(avg: Average): SupportLevel = {
154+
avg.evalMode match {
155+
case EvalMode.ANSI =>
156+
Incompatible(Some("ANSI mode is not supported"))
157+
case EvalMode.TRY =>
158+
Incompatible(Some("TRY mode is not supported"))
159+
case _ =>
160+
Compatible()
161+
}
162+
}
163+
152164
override def convert(
153165
aggExpr: AggregateExpression,
154166
avg: Average,
@@ -161,20 +173,6 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
161173
return None
162174
}
163175

164-
avg.evalMode match {
165-
case EvalMode.ANSI if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() =>
166-
withInfo(
167-
aggExpr,
168-
"ANSI mode is not supported. Set " +
169-
s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to allow it anyway")
170-
return None
171-
case EvalMode.TRY =>
172-
withInfo(aggExpr, "TRY mode is not supported")
173-
return None
174-
case _ =>
175-
// supported
176-
}
177-
178176
val child = avg.child
179177
val childExpr = exprToProto(child, inputs, binding)
180178
val dataType = serializeDataType(avg.dataType)
@@ -211,7 +209,20 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
211209
}
212210
}
213211
}
212+
214213
object CometSum extends CometAggregateExpressionSerde[Sum] {
214+
215+
override def getSupportLevel(sum: Sum): SupportLevel = {
216+
sum.evalMode match {
217+
case EvalMode.ANSI =>
218+
Incompatible(Some("ANSI mode is not supported"))
219+
case EvalMode.TRY =>
220+
Incompatible(Some("TRY mode is not supported"))
221+
case _ =>
222+
Compatible()
223+
}
224+
}
225+
215226
override def convert(
216227
aggExpr: AggregateExpression,
217228
sum: Sum,
@@ -224,20 +235,6 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
224235
return None
225236
}
226237

227-
sum.evalMode match {
228-
case EvalMode.ANSI if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() =>
229-
withInfo(
230-
aggExpr,
231-
"ANSI mode is not supported. Set " +
232-
s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true to allow it anyway")
233-
return None
234-
case EvalMode.TRY =>
235-
withInfo(aggExpr, "TRY mode is not supported")
236-
return None
237-
case _ =>
238-
// supported
239-
}
240-
241238
val childExpr = exprToProto(sum.child, inputs, binding)
242239
val dataType = serializeDataType(sum.dataType)
243240

0 commit comments

Comments
 (0)