@@ -21,6 +21,7 @@ package org.apache.comet.cost
2121
2222import scala .jdk .CollectionConverters ._
2323
24+ import org .apache .spark .internal .Logging
2425import org .apache .spark .sql .comet .{CometColumnarToRowExec , CometPlan , CometProjectExec }
2526import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec }
2627import org .apache .spark .sql .execution .SparkPlan
@@ -36,23 +37,24 @@ trait CometCostModel {
3637 def estimateCost (plan : SparkPlan ): CometCostEstimate
3738}
3839
39- class DefaultCometCostModel extends CometCostModel {
40+ class DefaultCometCostModel extends CometCostModel with Logging {
4041
4142 // optimistic default of 2x acceleration
4243 private val defaultAcceleration = 2.0
4344
4445 override def estimateCost (plan : SparkPlan ): CometCostEstimate = {
46+
47+ logTrace(s " estimateCost for $plan" )
48+
4549 // Walk the entire plan tree and accumulate costs
4650 var totalAcceleration = 0.0
4751 var operatorCount = 0
4852
4953 def collectOperatorCosts (node : SparkPlan ): Unit = {
5054 val operatorCost = estimateOperatorCost(node)
51- // scalastyle:off println
52- println(
53- s " [CostModel] Operator: ${node.getClass.getSimpleName}, " +
55+ logTrace(
56+ s " Operator: ${node.getClass.getSimpleName}, " +
5457 s " Cost: ${operatorCost.acceleration}" )
55- // scalastyle:on println
5658 totalAcceleration += operatorCost.acceleration
5759 operatorCount += 1
5860
@@ -70,11 +72,9 @@ class DefaultCometCostModel extends CometCostModel {
7072 1.0 // No acceleration if no operators
7173 }
7274
73- // scalastyle:off println
74- println(
75- s " [CostModel] Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " +
75+ logTrace(
76+ s " Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " +
7677 s " Average acceleration: $averageAcceleration" )
77- // scalastyle:on println
7878
7979 CometCostEstimate (averageAcceleration)
8080 }
@@ -83,29 +83,21 @@ class DefaultCometCostModel extends CometCostModel {
8383 private def estimateOperatorCost (plan : SparkPlan ): CometCostEstimate = {
8484 val result = plan match {
8585 case op : CometProjectExec =>
86- // scalastyle:off println
87- println(s " [CostModel] CometProjectExec found - evaluating expressions " )
88- // scalastyle:on println
86+ logTrace(" CometProjectExec found - evaluating expressions" )
8987 // Cast nativeOp to Operator and extract projection expressions
9088 val operator = op.nativeOp.asInstanceOf [OperatorOuterClass .Operator ]
9189 val projection = operator.getProjection
9290 val expressions = projection.getProjectListList.asScala
93- // scalastyle:off println
94- println(s " [CostModel] Found ${expressions.length} expressions in projection " )
95- // scalastyle:on println
91+ logTrace(s " Found ${expressions.length} expressions in projection " )
9692
9793 val costs = expressions.map { expr =>
9894 val cost = estimateCometExpressionCost(expr)
99- // scalastyle:off println
100- println(s " [CostModel] Expression cost: $cost" )
101- // scalastyle:on println
95+ logTrace(s " Expression cost: $cost" )
10296 cost
10397 }
10498 val total = costs.sum
10599 val average = total / expressions.length.toDouble
106- // scalastyle:off println
107- println(s " [CostModel] CometProjectExec total cost: $total, average: $average" )
108- // scalastyle:on println
100+ logTrace(s " CometProjectExec total cost: $total, average: $average" )
109101 CometCostEstimate (average)
110102
111103 case op : CometShuffleExchangeExec =>
@@ -121,38 +113,28 @@ class DefaultCometCostModel extends CometCostModel {
121113 case _ : CometColumnarToRowExec =>
122114 CometCostEstimate (1.0 )
123115 case _ : CometPlan =>
124- // scalastyle:off println
125- println(s " [CostModel] Generic CometPlan: ${plan.getClass.getSimpleName}" )
126- // scalastyle:on println
116+ logTrace(s " Generic CometPlan: ${plan.getClass.getSimpleName}" )
127117 CometCostEstimate (defaultAcceleration)
128118 case _ =>
129- // scalastyle:off println
130- println(s " [CostModel] Non-Comet operator: ${plan.getClass.getSimpleName}" )
131- // scalastyle:on println
119+ logTrace(s " Non-Comet operator: ${plan.getClass.getSimpleName}" )
132120 // Spark operator
133121 CometCostEstimate (1.0 )
134122 }
135123
136- // scalastyle:off println
137- println(s " [CostModel] ${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}" )
138- // scalastyle:on println
124+ logTrace(s " ${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}" )
139125 result
140126 }
141127
142128 /** Estimate the cost of a Comet protobuf expression */
143129 private def estimateCometExpressionCost (expr : ExprOuterClass .Expr ): Double = {
144130 val result = expr.getExprStructCase match {
145131 // Handle specialized expression types
146- case ExprOuterClass .Expr .ExprStructCase .SUBSTRING =>
147- // scalastyle:off println
148- println(s " [CostModel] Expression: SUBSTRING -> 6.3 " )
149- // scalastyle:on println
150- 6.3
132+ case ExprOuterClass .Expr .ExprStructCase .SUBSTRING => 6.3
151133
152134 // Handle generic scalar functions
153135 case ExprOuterClass .Expr .ExprStructCase .SCALARFUNC =>
154136 val funcName = expr.getScalarFunc.getFunc
155- val cost = funcName match {
137+ funcName match {
156138 // String expression numbers from CometStringExpressionBenchmark
157139 case " ascii" => 0.6
158140 case " octet_length" => 0.6
@@ -171,17 +153,11 @@ class DefaultCometCostModel extends CometCostModel {
171153 case " translate" => 0.8
172154 case _ => defaultAcceleration
173155 }
174- // scalastyle:off println
175- println(s " [CostModel] Expression: SCALARFUNC( $funcName) -> $cost" )
176- // scalastyle:on println
177- cost
178156
179157 case _ =>
180- // scalastyle:off println
181- println(
182- s " [CostModel] Expression: Unknown type ${expr.getExprStructCase} -> " +
158+ logTrace(
159+ s " Expression: Unknown type ${expr.getExprStructCase} -> " +
183160 s " $defaultAcceleration" )
184- // scalastyle:on println
185161 defaultAcceleration
186162 }
187163 result
0 commit comments