Skip to content

Commit 9aa04b1

Browse files
Try refactoring default cost model to make it faster
1 parent 3732d28 commit 9aa04b1

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

python/egglog/egraph.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2015,21 +2015,23 @@ def default_cost_model(egraph: EGraph, expr: BaseExpr, children_costs: list[int]
20152015
"""
20162016
A default cost model for an e-graph, which looks up costs set on function calls, or uses 1 as the default cost.
20172017
"""
2018-
from .builtins import Container, i64 # noqa: PLC0415
2018+
from .builtins import Container # noqa: PLC0415
20192019
from .deconstruct import get_callable_fn # noqa: PLC0415
20202020

2021-
# By default, all nodes have a cost of 1 except for containers which have a cost of 0
2022-
self_cost = 0 if isinstance(expr, Container) else 1
2023-
if (callable_fn := get_callable_fn(expr)) is not None:
2024-
# If this is a callable function with a set cost override the self cost
2025-
match get_callable_cost(callable_fn):
2026-
case int(self_cost):
2027-
pass
2028-
# If we have set the cost manually for this experession, use that instead
2029-
if egraph.has_custom_cost(callable_fn):
2030-
match egraph.lookup_function_value(get_cost(expr)):
2031-
case i64(i):
2032-
self_cost = i
2021+
# 1. First prefer if the expr has a custom cost set on it
2022+
if (
2023+
(callable_fn := get_callable_fn(expr)) is not None
2024+
and egraph.has_custom_cost(callable_fn)
2025+
and (i := egraph.lookup_function_value(get_cost(expr))) is not None
2026+
):
2027+
self_cost = int(i)
2028+
# 2. Else, check if this is a callable and it has a cost set on its declaration
2029+
elif callable_fn is not None and (callable_cost := get_callable_cost(callable_fn)) is not None:
2030+
self_cost = callable_cost
2031+
# 3. Else, if this is a container, it has no cost, otherwise it has a cost of 1
2032+
else:
2033+
# By default, all nodes have a cost of 1 except for containers which have a cost of 0
2034+
self_cost = 0 if isinstance(expr, Container) else 1
20332035
# Sum up the costs of the children and our own cost
20342036
return sum(children_costs, start=self_cost)
20352037

0 commit comments

Comments
 (0)