@@ -47,9 +47,9 @@ import org.apache.spark.sql.internal.SQLConf
4747import org .apache .spark .sql .types ._
4848
4949import org .apache .comet .{CometConf , CometExplainInfo , ExtendedExplainInfo }
50- import org .apache .comet .CometConf .{COMET_SPARK_TO_ARROW_ENABLED , COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST }
50+ import org .apache .comet .CometConf .{COMET_COST_BASED_OPTIMIZATION_ENABLED , COMET_COST_MODEL_CLASS , COMET_SPARK_TO_ARROW_ENABLED , COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST }
5151import org .apache .comet .CometSparkSessionExtensions ._
52- import org .apache .comet .cost .DefaultCometCostModel
52+ import org .apache .comet .cost .CometCostModel
5353import org .apache .comet .rules .CometExecRule .allExecs
5454import org .apache .comet .serde .{CometOperatorSerde , Compatible , Incompatible , OperatorOuterClass , Unsupported }
5555import org .apache .comet .serde .operator ._
@@ -98,6 +98,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
9898
9999 private lazy val showTransformations = CometConf .COMET_EXPLAIN_TRANSFORMATIONS .get()
100100
101+ // Cache the cost model to avoid loading the class on every call
102+ @ transient private lazy val costModel : Option [CometCostModel ] = {
103+ if (COMET_COST_BASED_OPTIMIZATION_ENABLED .get(conf)) {
104+ try {
105+ val costModelClassName = COMET_COST_MODEL_CLASS .get(conf)
106+ // scalastyle:off classforname
107+ val costModelClass = Class .forName(costModelClassName)
108+ // scalastyle:on classforname
109+ val constructor = costModelClass.getConstructor()
110+ Some (constructor.newInstance().asInstanceOf [CometCostModel ])
111+ } catch {
112+ case e : Exception =>
113+ logWarning(
114+ s " Failed to load cost model class: ${e.getMessage}. " +
115+ " Falling back to Spark query plan without cost-based optimization." )
116+ None
117+ }
118+ } else {
119+ None
120+ }
121+ }
122+
101123 private def applyCometShuffle (plan : SparkPlan ): SparkPlan = {
102124 plan.transformUp {
103125 case s : ShuffleExchangeExec if CometShuffleExchangeExec .nativeShuffleSupported(s) =>
@@ -347,15 +369,20 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
347369 override def apply (plan : SparkPlan ): SparkPlan = {
348370 val candidatePlan = _apply(plan)
349371
350- // TODO load cost model via config and reflection
351- val costModel = new DefaultCometCostModel
352- val costBefore = costModel.estimateCost(plan)
353- val costAfter = costModel.estimateCost(candidatePlan)
372+ // Only apply cost-based optimization if enabled and cost model is available
373+ val newPlan = costModel match {
374+ case Some (model) =>
375+ val costBefore = model.estimateCost(plan)
376+ val costAfter = model.estimateCost(candidatePlan)
354377
355- val newPlan = if (costAfter.acceleration > costBefore.acceleration) {
356- candidatePlan
357- } else {
358- plan
378+ if (costAfter.acceleration > costBefore.acceleration) {
379+ candidatePlan
380+ } else {
381+ plan
382+ }
383+ case None =>
384+ // Cost-based optimization is disabled or failed to load, return candidate plan
385+ candidatePlan
359386 }
360387
361388 if (showTransformations && ! newPlan.fastEquals(plan)) {
0 commit comments