@@ -40,10 +40,13 @@ class CometCostModelSuite extends CometTestBase {
4040 withCBOEnabled {
4141 withTempView(" test_data" ) {
4242 createSimpleTestData()
43- // Use a subquery to prevent projection pushdown
43+ // Create a more complex query that will trigger AQE with joins/aggregations
4444 val query = """
45- SELECT length(upper_text1) as len1, length(upper_text2) as len2
46- FROM (SELECT upper(text1) as upper_text1, upper(text2) as upper_text2 FROM test_data)
45+ SELECT t1.len1, t2.len2, COUNT(*) as cnt
46+ FROM (SELECT length(text1) as len1, text1 FROM test_data) t1
47+ JOIN (SELECT length(text2) as len2, text2 FROM test_data) t2
48+ ON t1.text1 = t2.text2
49+ GROUP BY t1.len1, t2.len2
4750 """
4851
4952 executeAndCheckOperator(
@@ -58,10 +61,13 @@ class CometCostModelSuite extends CometTestBase {
5861 withCBOEnabled {
5962 withTempView(" test_data" ) {
6063 createPaddedTestData()
61- // Use a subquery to prevent projection pushdown
64+ // Create a more complex query that will trigger AQE with joins/aggregations
6265 val query = """
63- SELECT trim(padded_text1) as trimmed1, trim(padded_text2) as trimmed2
64- FROM (SELECT text1 as padded_text1, text2 as padded_text2 FROM test_data)
66+ SELECT t1.trimmed1, t2.trimmed2, COUNT(*) as cnt
67+ FROM (SELECT trim(text1) as trimmed1, text1 FROM test_data) t1
68+ JOIN (SELECT trim(text2) as trimmed2, text2 FROM test_data) t2
69+ ON t1.text1 = t2.text2
70+ GROUP BY t1.trimmed1, t2.trimmed2
6571 """
6672
6773 executeAndCheckOperator(
@@ -76,10 +82,11 @@ class CometCostModelSuite extends CometTestBase {
7682 withCBODisabled {
7783 withTempView(" test_data" ) {
7884 createPaddedTestData()
79- // Use a subquery to prevent projection pushdown
85+ // Complex query without CBO
8086 val query = """
81- SELECT trim(padded_text1) as trimmed1
82- FROM (SELECT text1 as padded_text1 FROM test_data)
87+ SELECT trim(text1) as trimmed1, COUNT(*) as cnt
88+ FROM test_data
89+ GROUP BY trim(text1)
8390 """
8491
8592 executeAndCheckOperator(
@@ -94,11 +101,12 @@ class CometCostModelSuite extends CometTestBase {
94101 withCBOEnabled {
95102 withTempView(" test_data" ) {
96103 createSimpleTestData()
97- // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with subquery
104+ // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with aggregation
98105 // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet)
99106 val query = """
100- SELECT length(base_text) as fast_expr, ascii(base_text) as slow_expr
101- FROM (SELECT text1 as base_text FROM test_data)
107+ SELECT length(text1) as fast_expr, ascii(text1) as slow_expr, COUNT(*) as cnt
108+ FROM test_data
109+ GROUP BY length(text1), ascii(text1)
102110 """
103111
104112 executeAndCheckOperator(
@@ -117,7 +125,29 @@ class CometCostModelSuite extends CometTestBase {
117125 CometConf .COMET_COST_BASED_OPTIMIZATION_ENABLED .key -> " true" ,
118126 SQLConf .ADAPTIVE_EXECUTION_ENABLED .key -> " true" ,
119127 CometConf .COMET_EXEC_ENABLED .key -> " true" ,
120- CometConf .COMET_EXEC_PROJECT_ENABLED .key -> " true" ) {
128+ CometConf .COMET_EXEC_PROJECT_ENABLED .key -> " true" ,
129+ CometConf .COMET_EXEC_AGGREGATE_ENABLED .key -> " true" , // Enable aggregation for GROUP BY
130+ CometConf .COMET_EXEC_HASH_JOIN_ENABLED .key -> " true" , // Enable joins
131+ // Manually set the custom cost evaluator since plugin might not be loaded
132+ " spark.sql.adaptive.customCostEvaluatorClass" -> " org.apache.comet.cost.CometCostEvaluator" ,
133+ // Lower AQE thresholds to ensure it triggers on small test data
134+ " spark.sql.adaptive.advisoryPartitionSizeInBytes" -> " 1KB" ,
135+ " spark.sql.adaptive.coalescePartitions.minPartitionSize" -> " 1B" ) {
136+
137+ println(s " \n === CBO Configuration === " )
138+ println(s " COMET_ENABLED: ${spark.conf.get(CometConf .COMET_ENABLED .key)}" )
139+ println(s " COMET_COST_BASED_OPTIMIZATION_ENABLED: ${spark.conf.get(
140+ CometConf .COMET_COST_BASED_OPTIMIZATION_ENABLED .key)}" )
141+ println(
142+ s " ADAPTIVE_EXECUTION_ENABLED: ${spark.conf.get(SQLConf .ADAPTIVE_EXECUTION_ENABLED .key)}" )
143+ println(s " COMET_EXEC_ENABLED: ${spark.conf.get(CometConf .COMET_EXEC_ENABLED .key)}" )
144+ println(
145+ s " COMET_EXEC_PROJECT_ENABLED: ${spark.conf.get(CometConf .COMET_EXEC_PROJECT_ENABLED .key)}" )
146+
147+ // Check if custom cost evaluator is set
148+ val costEvaluator = spark.conf.getOption(" spark.sql.adaptive.customCostEvaluatorClass" )
149+ println(s " Custom cost evaluator: ${costEvaluator.getOrElse(" None" )}" )
150+
121151 f
122152 }
123153 }
@@ -172,30 +202,89 @@ class CometCostModelSuite extends CometTestBase {
172202 query : String ,
173203 expectedClass : Class [_],
174204 message : String ): Unit = {
205+
206+ println(s " \n === Executing Query === " )
207+ println(s " Query: $query" )
208+ println(s " Expected class: ${expectedClass.getSimpleName}" )
209+
175210 val result = sql(query)
176- result.collect() // Materialize the plan
177211
178- val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
212+ println(s " \n === Pre-execution Plans === " )
213+ println(" Logical Plan:" )
214+ println(result.queryExecution.logical)
215+ println(" \n Optimized Plan:" )
216+ println(result.queryExecution.optimizedPlan)
217+ println(" \n Spark Plan:" )
218+ println(result.queryExecution.sparkPlan)
179219
180- // scalastyle:off
220+ result.collect() // Materialize the plan
221+
222+ println(s " \n === Post-execution Plans === " )
223+ println(" Executed Plan (with AQE wrappers):" )
181224 println(result.queryExecution.executedPlan)
225+
226+ val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
227+ println(" \n Executed Plan (stripped AQE):" )
182228 println(executedPlan)
183229
230+ // Enhanced debugging: show complete plan tree structure
231+ println(" \n === Plan Tree Analysis ===" )
232+ debugPlanTree(executedPlan, 0 )
233+
184234 val hasProjectExec = findProjectExec(executedPlan)
185235
236+ println(s " \n === Project Analysis === " )
237+ println(s " Found project exec: ${hasProjectExec.isDefined}" )
238+ if (hasProjectExec.isDefined) {
239+ println(s " Actual class: ${hasProjectExec.get.getClass.getSimpleName}" )
240+ println(s " Expected class: ${expectedClass.getSimpleName}" )
241+ println(s " Is expected type: ${expectedClass.isInstance(hasProjectExec.get)}" )
242+ }
243+
186244 assert(hasProjectExec.isDefined, " Should have a project operator" )
187245 assert(
188246 expectedClass.isInstance(hasProjectExec.get),
189247 s " $message, got ${hasProjectExec.get.getClass.getSimpleName}" )
248+
249+ println(s " === Test PASSED === \n " )
190250 }
191251
/**
 * Helper method to find the first ProjectExec or CometProjectExec in the plan tree.
 *
 * The search is depth-first and pre-order: a node is tested before its children, and
 * children are explored in the order returned by `children`. The search stops at the
 * first match, so later subtrees are not visited (and produce no trace output).
 *
 * @param plan root of the physical plan (sub)tree to search
 * @return the first project operator encountered, or None when the tree contains none
 */
private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = {
  // Debug trace output below is intentional; suppress the println lint locally.
  // scalastyle:off println
  def searchPlan(node: SparkPlan): Option[SparkPlan] = {
    println(s"[findProjectExec] Checking node: ${node.getClass.getSimpleName}")

    node match {
      case _: ProjectExec | _: CometProjectExec =>
        println(s"[findProjectExec] Found project operator: ${node.getClass.getSimpleName}")
        Some(node)
      case _ =>
        // The iterator is lazy: each child subtree is searched only when the previous
        // ones yielded nothing. This replaces a non-local `return` from inside a
        // for-loop body, which is implemented via a control-flow exception.
        node.children.iterator.map(searchPlan).collectFirst { case Some(found) => found }
    }
  }
  // scalastyle:on println

  searchPlan(plan)
}
275+
/**
 * Debug helper that prints the plan tree, one node class name per line.
 *
 * Each line is prefixed with one space per level of depth. Nodes that are a
 * ProjectExec or CometProjectExec get an extra marker line directly below them.
 *
 * @param plan  node to print; its children are printed recursively
 * @param depth nesting level of `plan`, used to build the line prefix
 */
private def debugPlanTree(plan: SparkPlan, depth: Int): Unit = {
  val padding = " " * depth
  val nodeName = plan.getClass.getSimpleName
  println(s"$padding$nodeName")

  // Flag project operators so they are easy to spot in the dump
  plan match {
    case _: ProjectExec | _: CometProjectExec =>
      println(s"$padding -> PROJECT OPERATOR FOUND!")
    case _ => // not a project operator, nothing extra to print
  }

  // Descend into every child one level deeper
  for (child <- plan.children) {
    debugPlanTree(child, depth + 1)
  }
}
200289
201290 test(" Direct cost model test - fast vs slow expressions" ) {
0 commit comments