Commit 8bfb110

[WIP] support PartialMerge
1 parent 3a68703 commit 8bfb110

5 files changed: +51 −42 lines

native/core/src/execution/planner.rs

Lines changed: 7 additions & 4 deletions
@@ -864,7 +864,6 @@ impl PhysicalPlanner {
             OpStruct::HashAgg(agg) => {
                 assert_eq!(children.len(), 1);
                 let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
-
                 let group_exprs: PhyExprResult = agg
                     .grouping_exprs
                     .iter()
@@ -877,12 +876,16 @@ impl PhysicalPlanner {
                 let group_by = PhysicalGroupBy::new_single(group_exprs?);
                 let schema = child.schema();
 
-                let mode = if agg.mode == 0 {
-                    DFAggregateMode::Partial
-                } else {
+                // dbg!(agg);
+
+                let mode = if agg.mode.contains(&2) {
                     DFAggregateMode::Final
+                } else {
+                    DFAggregateMode::Partial
                 };
 
+                dbg!(&schema);
+
                 let agg_exprs: PhyAggResult = agg
                     .agg_exprs
                     .iter()

native/proto/src/proto/operator.proto

Lines changed: 6 additions & 2 deletions
@@ -207,7 +207,8 @@ message HashAggregate {
   repeated spark.spark_expression.Expr grouping_exprs = 1;
   repeated spark.spark_expression.AggExpr agg_exprs = 2;
   repeated spark.spark_expression.Expr result_exprs = 3;
-  AggregateMode mode = 5;
+  // Spark can have both Partial and PartialMerge together
+  repeated AggregateMode mode = 5;
 }
 
 message Limit {
@@ -249,7 +250,10 @@ message ParquetWriter {
 
 enum AggregateMode {
   Partial = 0;
-  Final = 1;
+  PartialMerge = 1;
+  Final = 2;
+  // Spark also supports COMPLETE, but it is only a stub for now
+  Complete = 3;
 }
 
 message Expand {
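
A minimal sketch, not part of this commit, of how the new repeated `mode` field can be consumed from the JVM side. The object name `RepeatedModeSketch` is made up for illustration, and it assumes the generated protobuf classes live under `org.apache.comet.serde` (the serde code below refers to `OperatorOuterClass.HashAggregate`); `addMode`/`getModeList` are the standard builders protobuf generates for a repeated field, matching the `addAllMode` call used later in this commit. It builds a `HashAggregate` that carries both Partial and PartialMerge, the mixed case the proto comment mentions; per the planner.rs hunk above, the native side picks `DFAggregateMode::Final` only when the list contains `Final` (value 2) and falls back to `Partial` otherwise.

import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.AggregateMode

object RepeatedModeSketch {
  def main(args: Array[String]): Unit = {
    // Build a HashAggregate whose repeated `mode` field carries two modes,
    // as can happen on the intermediate stages of a DISTINCT aggregation.
    val hashAgg = OperatorOuterClass.HashAggregate
      .newBuilder()
      .addMode(AggregateMode.Partial)
      .addMode(AggregateMode.PartialMerge)
      .build()
    // Prints [Partial, PartialMerge]; since Final is absent, the native
    // planner would run this operator in DataFusion's Partial mode.
    println(hashAgg.getModeList)
  }
}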

spark/src/main/scala/org/apache/comet/serde/namedExpressions.scala

Lines changed: 1 addition & 2 deletions
@@ -47,8 +47,7 @@ object CometAttributeReference extends CometExpressionSerde[AttributeReference]
     if (dataType.isDefined) {
       if (binding) {
         // Spark may produce unresolvable attributes in some cases,
-        // for example https://github.com/apache/datafusion-comet/issues/925.
-        // So, we allow the binding to fail.
+        // for example partial aggregation or https://github.com/apache/datafusion-comet/issues/925.
         val boundRef: Any = BindReferences
           .bindReference(attr, inputs, allowFailures = true)

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 32 additions & 32 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Complete, Final, Partial, PartialMerge}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -1066,12 +1066,11 @@ trait CometBaseAggregate {
 
     val modes = aggregate.aggregateExpressions.map(_.mode).distinct
     // In distinct aggregates there can be a combination of modes
-    val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
     // if there is Comet partial aggregation.
     val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
 
-    if (multiMode || sparkFinalMode) {
+    if (sparkFinalMode) {
       return None
     }
 
@@ -1144,33 +1143,34 @@ trait CometBaseAggregate {
       hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
       Some(builder.setHashAgg(hashAggBuilder).build())
     } else {
-      val modes = aggregateExpressions.map(_.mode).distinct
-
-      if (modes.size != 1) {
-        // This shouldn't happen as all aggregation expressions should share the same mode.
-        // Fallback to Spark nevertheless here.
-        withInfo(aggregate, "All aggregate expressions do not have the same mode")
-        return None
-      }
-
-      val mode = modes.head match {
-        case Partial => CometAggregateMode.Partial
-        case Final => CometAggregateMode.Final
-        case _ =>
-          withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
-          return None
-      }
+      // `output` is only used when `binding` is true (i.e., non-Final)
+      val output = child.output
 
       // In final mode, the aggregate expressions are bound to the output of the
       // child and partial aggregate expressions buffer attributes produced by partial
       // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
       // we don't have to do this because we don't use the merging expression.
-      val binding = mode != CometAggregateMode.Final
-      // `output` is only used when `binding` is true (i.e., non-Final)
-      val output = child.output
-
-      val aggExprs =
-        aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
+      //
+      // It is possible to have multiple modes for queries with a DISTINCT agg expression,
+      // so Spark can use Partial and PartialMerge at the same time.
+      val (aggExprs, aggModes) =
+        aggregateExpressions
+          .map(a =>
+            (
+              aggExprToProto(
+                a,
+                output,
+                a.mode != PartialMerge && a.mode != Final,
+                aggregate.conf),
+              a.mode match {
+                case Partial => CometAggregateMode.Partial
+                case PartialMerge => CometAggregateMode.PartialMerge
+                case Final => CometAggregateMode.Final
+                case mode =>
+                  withInfo(aggregate, s"Unsupported Aggregation Mode $mode")
+                  return None
+              }))
+          .unzip
 
       if (aggExprs.exists(_.isEmpty)) {
         withInfo(
@@ -1185,7 +1185,9 @@ trait CometBaseAggregate {
       val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
       hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
      hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
-      if (mode == CometAggregateMode.Final) {
+      // Spark plans the Final stage on its own,
+      // so if any entry is Final, all entries are Final.
+      if (modes.contains(Final)) {
         val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
         val resultExprs = resultExpressions.map(exprToProto(_, attributes))
         if (resultExprs.exists(_.isEmpty)) {
@@ -1197,7 +1199,7 @@ trait CometBaseAggregate {
         }
         hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
       }
-      hashAggBuilder.setModeValue(mode.getNumber)
+      hashAggBuilder.addAllMode(aggModes.asJava)
       Some(builder.setHashAgg(hashAggBuilder).build())
     } else {
       val allChildren: Seq[Expression] =
@@ -1323,8 +1325,6 @@ case class CometHashAggregateExec(
   // modes is empty too. If aggExprs is not empty, we need to verify all the
   // aggregates have the same mode.
   val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
-  assert(modes.length == 1 || modes.isEmpty)
-  val mode = modes.headOption
 
   override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)
 
@@ -1341,7 +1341,7 @@ case class CometHashAggregateExec(
   }
 
   override def stringArgs: Iterator[Any] =
-    Iterator(input, mode, groupingExpressions, aggregateExpressions, child)
+    Iterator(input, modes, groupingExpressions, aggregateExpressions, child)
 
   override def equals(obj: Any): Boolean = {
     obj match {
@@ -1350,7 +1350,7 @@ case class CometHashAggregateExec(
         this.groupingExpressions == other.groupingExpressions &&
         this.aggregateExpressions == other.aggregateExpressions &&
         this.input == other.input &&
-        this.mode == other.mode &&
+        this.modes == other.modes &&
         this.child == other.child &&
         this.serializedPlanOpt == other.serializedPlanOpt
       case _ =>
@@ -1359,7 +1359,7 @@ case class CometHashAggregateExec(
     }
 
   override def hashCode(): Int =
-    Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)
+    Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child)
 
   override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 }
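
The comments in the hunks above note that a query with a DISTINCT aggregate can put Partial and PartialMerge expressions into a single HashAggregateExec. A minimal sketch of how to reproduce such a plan, not part of this commit; the object name `DistinctModesSketch` and the temp view `t` are hypothetical, and a local Spark session is assumed.

import org.apache.spark.sql.SparkSession

object DistinctModesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("distinct-modes-sketch").master("local[1]").getOrCreate()
    import spark.implicits._
    Seq((1, "a"), (1, "b"), (2, "a")).toDF("k", "v").createOrReplaceTempView("t")
    // Mixing count(DISTINCT v) with count(v) makes Spark plan an intermediate
    // HashAggregateExec whose aggregate expressions are in PartialMerge mode
    // (for count(v)) and Partial mode (for the distinct count) at the same time.
    spark.sql("SELECT k, count(DISTINCT v), count(v) FROM t GROUP BY k").explain("formatted")
    spark.stop()
  }
}

With the repeated `mode` field, each aggregate expression can carry its own mode entry instead of forcing a fallback when the modes differ.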

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 5 additions & 2 deletions
@@ -481,9 +481,12 @@ class CometExecSuite extends CometTestBase {
         case s: CometHashAggregateExec => s
       }.get
 
-      assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
+      assert(
+        agg.modes.headOption.isDefined && agg.modes.headOption.get.isInstanceOf[AggregateMode])
       val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
-      assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
+      assert(
+        newAgg.modes.headOption.isDefined && newAgg.modes.headOption.get
+          .isInstanceOf[AggregateMode])
     }
   }
