@@ -2015,21 +2015,23 @@ def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]
20152015 """
20162016 A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost.
20172017 """
2018- from .builtins import Container , i64 # noqa: PLC0415
2018+ from .builtins import Container # noqa: PLC0415
20192019 from .deconstruct import get_callable_fn # noqa: PLC0415
20202020
2021- # By default, all nodes have a cost of 1 except for containers which have a cost of 0
2022- self_cost = 0 if isinstance (expr , Container ) else 1
2023- if (callable_fn := get_callable_fn (expr )) is not None :
2024- # If this is a callable function with a set cost override the self cost
2025- match get_callable_cost (callable_fn ):
2026- case int (self_cost ):
2027- pass
2028- # If we have set the cost manually for this experession, use that instead
2029- if egraph .has_custom_cost (callable_fn ):
2030- match egraph .lookup_function_value (get_cost (expr )):
2031- case i64 (i ):
2032- self_cost = i
2021+ # 1. First prefer if the expr has a custom cost set on it
2022+ if (
2023+ (callable_fn := get_callable_fn (expr )) is not None
2024+ and egraph .has_custom_cost (callable_fn )
2025+ and (i := egraph .lookup_function_value (get_cost (expr ))) is not None
2026+ ):
2027+ self_cost = int (i )
2028+ # 2. Else, check if this is a callable and it has a cost set on its declaration
2029+ elif callable_fn is not None and (callable_cost := get_callable_cost (callable_fn )) is not None :
2030+ self_cost = callable_cost
2031+ # 3. Else, if this is a container, it has no cost, otherwise it has a cost of 1
2032+ else :
2033+ # By default, all nodes have a cost of 1 except for containers which have a cost of 0
2034+ self_cost = 0 if isinstance (expr , Container ) else 1
20332035 # Sum up the costs of the children and our own cost
20342036 return sum (children_costs , start = self_cost )
20352037
0 commit comments