Skip to content

Commit 7b19a05

Browse files
committed
chore: check missingInput for Comet plan nodes
1 parent c3de884 commit 7b19a05

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.spark.sql.comet
2222
import org.apache.spark.TaskContext
2323
import org.apache.spark.rdd.{ParallelCollectionRDD, RDD}
2424
import org.apache.spark.serializer.Serializer
25-
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
25+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression, SortOrder}
2626
import org.apache.spark.sql.catalyst.util.truncatedString
2727
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec}
2828
import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer}
@@ -98,10 +98,15 @@ case class CometTakeOrderedAndProjectExec(
9898
child: SparkPlan)
9999
extends CometExec
100100
with UnaryExecNode {
101+
102+
override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(projectList)
103+
101104
private lazy val writeMetrics =
102105
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
106+
103107
private lazy val readMetrics =
104108
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
109+
105110
override lazy val metrics: Map[String, SQLMetric] = Map(
106111
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
107112
"numPartitions" -> SQLMetrics.createMetric(

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,15 +411,15 @@ abstract class CometTestBase
411411
// checks the plan node has no missing inputs
412412
// such nodes are represented in the plan with an exclamation mark !
413413
// example: !CometWindowExec
414-
private def checkPlanNotMissingInput(plan: SparkPlan): Unit = {
414+
protected def checkPlanNotMissingInput(plan: SparkPlan): Unit = {
415415
def hasMissingInput(node: SparkPlan): Boolean = {
416416
node.missingInput.nonEmpty && node.children.nonEmpty
417417
}
418418

419419
val isCometNode = plan.nodeName.startsWith("Comet")
420420

421421
if (isCometNode && hasMissingInput(plan)) {
422-
assert(false, s"Plan ${plan.nodeName} has invalid missingInput")
422+
assert(false, s"Plan node `${plan.nodeName}` has invalid missingInput")
423423
}
424424

425425
// Otherwise recursively check children

0 commit comments

Comments (0)