Skip to content

Commit a7acf9a

Browse files
committed
use configs
1 parent 304e9f7 commit a7acf9a

File tree

2 files changed

+38
-11
lines changed

2 files changed

+38
-11
lines changed

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ object CometConf extends ShimCometConf {
754754
.booleanConf
755755
.createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false)
756756

757-
val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] =
757+
val `COMET_COST_BASED_OPTIMIZATION_ENABLED`: ConfigEntry[Boolean] =
758758
conf("spark.comet.cost.enabled")
759759
.category(CATEGORY_TUNING)
760760
.doc(

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ import org.apache.spark.sql.internal.SQLConf
4747
import org.apache.spark.sql.types._
4848

4949
import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
50-
import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
50+
import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_COST_MODEL_CLASS, COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
5151
import org.apache.comet.CometSparkSessionExtensions._
52-
import org.apache.comet.cost.DefaultCometCostModel
52+
import org.apache.comet.cost.CometCostModel
5353
import org.apache.comet.rules.CometExecRule.allExecs
5454
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
5555
import org.apache.comet.serde.operator._
@@ -98,6 +98,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
9898

9999
private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
100100

101+
// Cache the cost model so its class is loaded and instantiated at most once (None if disabled or loading fails)
102+
@transient private lazy val costModel: Option[CometCostModel] = {
103+
if (COMET_COST_BASED_OPTIMIZATION_ENABLED.get(conf)) {
104+
try {
105+
val costModelClassName = COMET_COST_MODEL_CLASS.get(conf)
106+
// scalastyle:off classforname
107+
val costModelClass = Class.forName(costModelClassName)
108+
// scalastyle:on classforname
109+
val constructor = costModelClass.getConstructor()
110+
Some(constructor.newInstance().asInstanceOf[CometCostModel])
111+
} catch {
112+
case e: Exception =>
113+
logWarning(
114+
s"Failed to load cost model class: ${e.getMessage}. " +
115+
"Falling back to Spark query plan without cost-based optimization.")
116+
None
117+
}
118+
} else {
119+
None
120+
}
121+
}
122+
101123
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
102124
plan.transformUp {
103125
case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
@@ -347,15 +369,20 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
347369
override def apply(plan: SparkPlan): SparkPlan = {
348370
val candidatePlan = _apply(plan)
349371

350-
// TODO load cost model via config and reflection
351-
val costModel = new DefaultCometCostModel
352-
val costBefore = costModel.estimateCost(plan)
353-
val costAfter = costModel.estimateCost(candidatePlan)
372+
// Only apply cost-based optimization if enabled and cost model is available
373+
val newPlan = costModel match {
374+
case Some(model) =>
375+
val costBefore = model.estimateCost(plan)
376+
val costAfter = model.estimateCost(candidatePlan)
354377

355-
val newPlan = if (costAfter.acceleration > costBefore.acceleration) {
356-
candidatePlan
357-
} else {
358-
plan
378+
if (costAfter.acceleration > costBefore.acceleration) {
379+
candidatePlan
380+
} else {
381+
plan
382+
}
383+
case None =>
384+
// Cost-based optimization is disabled or the cost model failed to load; return the candidate plan
385+
candidatePlan
359386
}
360387

361388
if (showTransformations && !newPlan.fastEquals(plan)) {

0 commit comments

Comments
 (0)