Commit 162ee47

test

1 parent 8adc080 commit 162ee47

3 files changed: +181 -30 lines changed

spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala

Lines changed: 4 additions & 2 deletions

@@ -88,9 +88,11 @@ class CometCostEvaluator extends CostEvaluator with Logging {
     // - 0.8x acceleration -> cost = 1.25 (25% more cost)
     val costValue = 1.0 / estimate.acceleration
 
-    logDebug(
-      s"Cost evaluation for ${plan.getClass.getSimpleName}: " +
+    // scalastyle:off println
+    println(
+      s"[CostEvaluator] Plan: ${plan.getClass.getSimpleName}, " +
         s"acceleration=${estimate.acceleration}, cost=$costValue")
+    // scalastyle:on println
 
     // Create Cost object with the calculated value
    CometCost(costValue)
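
The cost formula in this hunk is simple enough to check by hand. Below is a minimal standalone sketch (not part of the commit; the object name is made up for illustration) of how cost = 1.0 / acceleration maps an estimated acceleration to an AQE cost:

object CostFormulaSketch {
  // Cost is the inverse of the estimated acceleration, so values below 1.0 favor the plan.
  def cost(acceleration: Double): Double = 1.0 / acceleration

  def main(args: Array[String]): Unit = {
    // 2.0x speedup -> cost 0.5, 1.0x -> cost 1.0, 0.8x slowdown -> cost 1.25
    Seq(2.0, 1.0, 0.8).foreach(a => println(f"acceleration=$a%.1f -> cost=${cost(a)}%.2f"))
  }
}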

spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala

Lines changed: 69 additions & 9 deletions

@@ -26,7 +26,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet
 import org.apache.spark.sql.execution.SparkPlan
 
 import org.apache.comet.DataTypeSupport
-import org.apache.comet.serde.ExprOuterClass
+import org.apache.comet.serde.{ExprOuterClass, OperatorOuterClass}
 
 case class CometCostEstimate(acceleration: Double)
 
@@ -48,6 +48,11 @@ class DefaultCometCostModel extends CometCostModel {
 
   def collectOperatorCosts(node: SparkPlan): Unit = {
     val operatorCost = estimateOperatorCost(node)
+    // scalastyle:off println
+    println(
+      s"[CostModel] Operator: ${node.getClass.getSimpleName}, " +
+        s"Cost: ${operatorCost.acceleration}")
+    // scalastyle:on println
     totalAcceleration += operatorCost.acceleration
     operatorCount += 1
 
@@ -65,16 +70,44 @@ class DefaultCometCostModel extends CometCostModel {
       1.0 // No acceleration if no operators
     }
 
+    // scalastyle:off println
+    println(
+      s"[CostModel] Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " +
+        s"Average acceleration: $averageAcceleration")
+    // scalastyle:on println
+
     CometCostEstimate(averageAcceleration)
   }
 
   /** Estimate the cost of a single operator */
   private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = {
-    plan match {
+    val result = plan match {
       case op: CometProjectExec =>
-        val expressions = op.nativeOp.getProjection.getProjectListList.asScala
-        val total: Double = expressions.map(estimateCometExpressionCost).sum
-        CometCostEstimate(total / expressions.length.toDouble)
+        // scalastyle:off println
+        println(s"[CostModel] CometProjectExec found - evaluating expressions")
+        // scalastyle:on println
+        // Cast nativeOp to Operator and extract projection expressions
+        val operator = op.nativeOp.asInstanceOf[OperatorOuterClass.Operator]
+        val projection = operator.getProjection
+        val expressions = projection.getProjectListList.asScala
+        // scalastyle:off println
+        println(s"[CostModel] Found ${expressions.length} expressions in projection")
+        // scalastyle:on println
+
+        val costs = expressions.map { expr =>
+          val cost = estimateCometExpressionCost(expr)
+          // scalastyle:off println
+          println(s"[CostModel] Expression cost: $cost")
+          // scalastyle:on println
+          cost
+        }
+        val total = costs.sum
+        val average = total / expressions.length.toDouble
+        // scalastyle:off println
+        println(s"[CostModel] CometProjectExec total cost: $total, average: $average")
+        // scalastyle:on println
+        CometCostEstimate(average)
+
       case op: CometShuffleExchangeExec =>
         op.shuffleType match {
           case CometNativeShuffle => CometCostEstimate(1.5)
@@ -88,22 +121,38 @@ class DefaultCometCostModel extends CometCostModel {
       case _: CometColumnarToRowExec =>
         CometCostEstimate(1.0)
       case _: CometPlan =>
+        // scalastyle:off println
+        println(s"[CostModel] Generic CometPlan: ${plan.getClass.getSimpleName}")
+        // scalastyle:on println
        CometCostEstimate(defaultAcceleration)
       case _ =>
+        // scalastyle:off println
+        println(s"[CostModel] Non-Comet operator: ${plan.getClass.getSimpleName}")
+        // scalastyle:on println
         // Spark operator
         CometCostEstimate(1.0)
     }
+
+    // scalastyle:off println
+    println(s"[CostModel] ${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}")
+    // scalastyle:on println
+    result
   }
 
   /** Estimate the cost of a Comet protobuf expression */
   private def estimateCometExpressionCost(expr: ExprOuterClass.Expr): Double = {
-    expr.getExprStructCase match {
+    val result = expr.getExprStructCase match {
       // Handle specialized expression types
-      case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => 6.3
+      case ExprOuterClass.Expr.ExprStructCase.SUBSTRING =>
+        // scalastyle:off println
+        println(s"[CostModel] Expression: SUBSTRING -> 6.3")
+        // scalastyle:on println
+        6.3
 
       // Handle generic scalar functions
       case ExprOuterClass.Expr.ExprStructCase.SCALARFUNC =>
-        expr.getScalarFunc.getFunc match {
+        val funcName = expr.getScalarFunc.getFunc
+        val cost = funcName match {
           // String expression numbers from CometStringExpressionBenchmark
           case "ascii" => 0.6
           case "octet_length" => 0.6
@@ -122,9 +171,20 @@ class DefaultCometCostModel extends CometCostModel {
           case "translate" => 0.8
           case _ => defaultAcceleration
         }
+        // scalastyle:off println
+        println(s"[CostModel] Expression: SCALARFUNC($funcName) -> $cost")
+        // scalastyle:on println
+        cost
 
-      case _ => defaultAcceleration
+      case _ =>
+        // scalastyle:off println
+        println(
+          s"[CostModel] Expression: Unknown type ${expr.getExprStructCase} -> " +
+            s"$defaultAcceleration")
+        // scalastyle:on println
+        defaultAcceleration
     }
+    result
   }
 
 }
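
To see how the per-expression numbers in this file roll up into an operator estimate, here is a hedged, self-contained sketch of the averaging mirrored by the CometProjectExec branch and the SCALARFUNC table above. The object name, the map, and the default value are illustrative stand-ins, not the file's actual members:

object OperatorCostSketch {
  // Small subset of the benchmark-derived accelerations shown in the diff above.
  val scalarFuncAcceleration: Map[String, Double] =
    Map("ascii" -> 0.6, "octet_length" -> 0.6, "translate" -> 0.8)

  // Placeholder default; the real defaultAcceleration lives in DefaultCometCostModel.
  val defaultAcceleration: Double = 1.0

  def expressionAcceleration(func: String): Double =
    scalarFuncAcceleration.getOrElse(func, defaultAcceleration)

  // A projection's estimate is the arithmetic mean of its expressions' accelerations.
  def projectionAcceleration(funcs: Seq[String]): Double =
    funcs.map(expressionAcceleration).sum / funcs.length.toDouble

  def main(args: Array[String]): Unit =
    println(projectionAcceleration(Seq("ascii", "translate"))) // (0.6 + 0.8) / 2 = 0.7
}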

spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala

Lines changed: 108 additions & 19 deletions

@@ -40,10 +40,13 @@ class CometCostModelSuite extends CometTestBase {
     withCBOEnabled {
       withTempView("test_data") {
         createSimpleTestData()
-        // Use a subquery to prevent projection pushdown
+        // Create a more complex query that will trigger AQE with joins/aggregations
         val query = """
-          SELECT length(upper_text1) as len1, length(upper_text2) as len2
-          FROM (SELECT upper(text1) as upper_text1, upper(text2) as upper_text2 FROM test_data)
+          SELECT t1.len1, t2.len2, COUNT(*) as cnt
+          FROM (SELECT length(text1) as len1, text1 FROM test_data) t1
+          JOIN (SELECT length(text2) as len2, text2 FROM test_data) t2
+          ON t1.text1 = t2.text2
+          GROUP BY t1.len1, t2.len2
          """
 
         executeAndCheckOperator(
@@ -58,10 +61,13 @@ class CometCostModelSuite extends CometTestBase {
     withCBOEnabled {
       withTempView("test_data") {
         createPaddedTestData()
-        // Use a subquery to prevent projection pushdown
+        // Create a more complex query that will trigger AQE with joins/aggregations
         val query = """
-          SELECT trim(padded_text1) as trimmed1, trim(padded_text2) as trimmed2
-          FROM (SELECT text1 as padded_text1, text2 as padded_text2 FROM test_data)
+          SELECT t1.trimmed1, t2.trimmed2, COUNT(*) as cnt
+          FROM (SELECT trim(text1) as trimmed1, text1 FROM test_data) t1
+          JOIN (SELECT trim(text2) as trimmed2, text2 FROM test_data) t2
+          ON t1.text1 = t2.text2
+          GROUP BY t1.trimmed1, t2.trimmed2
          """
 
         executeAndCheckOperator(
@@ -76,10 +82,11 @@ class CometCostModelSuite extends CometTestBase {
     withCBODisabled {
       withTempView("test_data") {
         createPaddedTestData()
-        // Use a subquery to prevent projection pushdown
+        // Complex query without CBO
         val query = """
-          SELECT trim(padded_text1) as trimmed1
-          FROM (SELECT text1 as padded_text1 FROM test_data)
+          SELECT trim(text1) as trimmed1, COUNT(*) as cnt
+          FROM test_data
+          GROUP BY trim(text1)
          """
 
         executeAndCheckOperator(
@@ -94,11 +101,12 @@ class CometCostModelSuite extends CometTestBase {
     withCBOEnabled {
       withTempView("test_data") {
         createSimpleTestData()
-        // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with subquery
+        // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with aggregation
         // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet)
         val query = """
-          SELECT length(base_text) as fast_expr, ascii(base_text) as slow_expr
-          FROM (SELECT text1 as base_text FROM test_data)
+          SELECT length(text1) as fast_expr, ascii(text1) as slow_expr, COUNT(*) as cnt
+          FROM test_data
+          GROUP BY length(text1), ascii(text1)
          """
 
         executeAndCheckOperator(
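
The arithmetic in the comment of this hunk can be spot-checked directly. A quick sketch, assuming the stated per-expression accelerations (length ≈ 9.1x, ascii ≈ 0.6x) and the cost formula from CometCostEvaluator; the object name is illustrative:

object MixedExpressionCostCheck {
  def main(args: Array[String]): Unit = {
    val avg = (9.1 + 0.6) / 2 // 4.85x average acceleration across the two expressions
    val cost = 1.0 / avg      // ≈ 0.206, well under 1.0, so the Comet plan is still preferred
    println(f"average=$avg%.2f cost=$cost%.3f")
  }
}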
@@ -117,7 +125,29 @@ class CometCostModelSuite extends CometTestBase {
       CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true",
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
       CometConf.COMET_EXEC_ENABLED.key -> "true",
-      CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") {
+      CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true", // Enable aggregation for GROUP BY
+      CometConf.COMET_EXEC_HASH_JOIN_ENABLED.key -> "true", // Enable joins
+      // Manually set the custom cost evaluator since plugin might not be loaded
+      "spark.sql.adaptive.customCostEvaluatorClass" -> "org.apache.comet.cost.CometCostEvaluator",
+      // Lower AQE thresholds to ensure it triggers on small test data
+      "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1KB",
+      "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1B") {
+
+      println(s"\n=== CBO Configuration ===")
+      println(s"COMET_ENABLED: ${spark.conf.get(CometConf.COMET_ENABLED.key)}")
+      println(s"COMET_COST_BASED_OPTIMIZATION_ENABLED: ${spark.conf.get(
+          CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key)}")
+      println(
+        s"ADAPTIVE_EXECUTION_ENABLED: ${spark.conf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key)}")
+      println(s"COMET_EXEC_ENABLED: ${spark.conf.get(CometConf.COMET_EXEC_ENABLED.key)}")
+      println(
+        s"COMET_EXEC_PROJECT_ENABLED: ${spark.conf.get(CometConf.COMET_EXEC_PROJECT_ENABLED.key)}")
+
+      // Check if custom cost evaluator is set
+      val costEvaluator = spark.conf.getOption("spark.sql.adaptive.customCostEvaluatorClass")
+      println(s"Custom cost evaluator: ${costEvaluator.getOrElse("None")}")
+
       f
     }
   }
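
Outside the test harness, the same evaluator can be wired in through the AQE config the hunk above sets. A hedged sketch assuming a local SparkSession (the object name is illustrative; the config keys are the ones used in the test):

import org.apache.spark.sql.SparkSession

object CostEvaluatorConfigSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.adaptive.enabled", "true")
      // Point AQE at the commit's evaluator class, as the test configuration does.
      .config(
        "spark.sql.adaptive.customCostEvaluatorClass",
        "org.apache.comet.cost.CometCostEvaluator")
      .getOrCreate()

    println(spark.conf.get("spark.sql.adaptive.customCostEvaluatorClass"))
    spark.stop()
  }
}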
@@ -172,30 +202,89 @@ class CometCostModelSuite extends CometTestBase {
       query: String,
       expectedClass: Class[_],
       message: String): Unit = {
+
+    println(s"\n=== Executing Query ===")
+    println(s"Query: $query")
+    println(s"Expected class: ${expectedClass.getSimpleName}")
+
     val result = sql(query)
-    result.collect() // Materialize the plan
 
-    val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
+    println(s"\n=== Pre-execution Plans ===")
+    println("Logical Plan:")
+    println(result.queryExecution.logical)
+    println("\nOptimized Plan:")
+    println(result.queryExecution.optimizedPlan)
+    println("\nSpark Plan:")
+    println(result.queryExecution.sparkPlan)
 
-    // scalastyle:off
+    result.collect() // Materialize the plan
+
+    println(s"\n=== Post-execution Plans ===")
+    println("Executed Plan (with AQE wrappers):")
     println(result.queryExecution.executedPlan)
+
+    val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
+    println("\nExecuted Plan (stripped AQE):")
     println(executedPlan)
 
+    // Enhanced debugging: show complete plan tree structure
+    println("\n=== Plan Tree Analysis ===")
+    debugPlanTree(executedPlan, 0)
+
     val hasProjectExec = findProjectExec(executedPlan)
 
+    println(s"\n=== Project Analysis ===")
+    println(s"Found project exec: ${hasProjectExec.isDefined}")
+    if (hasProjectExec.isDefined) {
+      println(s"Actual class: ${hasProjectExec.get.getClass.getSimpleName}")
+      println(s"Expected class: ${expectedClass.getSimpleName}")
+      println(s"Is expected type: ${expectedClass.isInstance(hasProjectExec.get)}")
+    }
+
     assert(hasProjectExec.isDefined, "Should have a project operator")
     assert(
       expectedClass.isInstance(hasProjectExec.get),
       s"$message, got ${hasProjectExec.get.getClass.getSimpleName}")
+
+    println(s"=== Test PASSED ===\n")
   }
 
   /** Helper method to find ProjectExec or CometProjectExec in the plan tree */
   private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = {
+    // More robust recursive search that handles deep nesting
+    def searchPlan(node: SparkPlan): Option[SparkPlan] = {
+      println(s"[findProjectExec] Checking node: ${node.getClass.getSimpleName}")
+
+      if (node.isInstanceOf[ProjectExec] || node.isInstanceOf[CometProjectExec]) {
+        println(s"[findProjectExec] Found project operator: ${node.getClass.getSimpleName}")
+        Some(node)
+      } else {
+        // Search all children recursively
+        for (child <- node.children) {
+          searchPlan(child) match {
+            case Some(found) => return Some(found)
+            case None => // continue searching
+          }
+        }
+        None
+      }
+    }
+
+    searchPlan(plan)
+  }
+
+  /** Debug method to print complete plan tree structure */
+  private def debugPlanTree(plan: SparkPlan, depth: Int): Unit = {
+    val indent = " " * depth
+    println(s"$indent${plan.getClass.getSimpleName}")
+
+    // Also show if this is a project operator
     if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[CometProjectExec]) {
-      Some(plan)
-    } else {
-      plan.children.flatMap(findProjectExec).headOption
+      println(s"$indent -> PROJECT OPERATOR FOUND!")
     }
+
+    // Recursively print children
+    plan.children.foreach(child => debugPlanTree(child, depth + 1))
   }
 
   test("Direct cost model test - fast vs slow expressions") {
