Commit 05934d9

Change cost model to just be function
1 parent 7a615e9 commit 05934d9

7 files changed: +154 -305 lines changed

docs/reference/python-integration.md

Lines changed: 21 additions & 38 deletions
@@ -638,47 +638,30 @@ egraph.saturate(r)

 ## Custom Cost Models

-Custom cost models are also supported by subclassing `CostModel[T]` or `DefaultCostModel` and passing in an instance as the `cost_model` kwargs to `EGraph.extract`. The `CostModel` is parameterized by a cost type `T`, which in the `DefaultCostModel` is `int`. Any cost must be able to be compared to choose the lowest cost.
+By default, when extracting from the e-graph, we use a simple cost model that looks at the costs assigned to each
+function and any custom costs set with `set_cost`, and finds the lowest cost expression looking at the total tree size.

-The `Expr`s passed to your cost model represent partial program trees. Any builtin values (containers or single primitives) will be fully evaluated, but any values that return user defined classes will be left as opaque "values",
-representing an e-class in the e-graph. The only thing you can do with values is to compare them to each other or
-use them in `EGraph.lookup_function_value` to lookup the resulting value of a call with values in it.
+Custom cost models are also supported, which can be passed into `extract` as the `cost_model` keyword argument. They
+are defined as functions following the `CostModel` protocol that take in an e-graph, an expression, and the costs of the children, and return the total cost of that expression. Costs don't have to be integers; they can be any type that supports comparison.

-For example, here is a cost model that uses boolean values to determine if a model is extractable or not:
+There are a few builtin cost models:
+
+- `default_cost_model`: The default cost model, which uses integer costs and sums them up.
+- `greedy_dag_cost_model(inner_cost_model=default_cost_model)`: A cost model which uses a greedy DAG algorithm to find the lowest cost expression, allowing for shared sub-expressions. It takes in another cost model to use for the base costs of each expression.
+
+Note that when passed into your cost model, the expression won't be a full tree. Instead, only the top level call will be present, and all of its arguments will be opaque "value" expressions, representing e-classes in the e-graph. You can't do much with them except use them to construct other expressions to pass into `egraph.lookup_function_value` to get the resulting value of a call with those arguments. The only exception is that all builtin types, like ints, vecs, strings, etc., will be fully evaluated recursively, so they can be matched against.
+
+For example, here is a cost model whose cost is a boolean indicating whether the value is an even integer:

 ```{code-cell} python
+def is_even_cost_model(egraph: EGraph, expr: Expr, children_costs: list[bool]) -> bool:
+    from egglog import i64  # noqa: PLC0415

-class MyExpr(Expr):
-    def __init__(self) -> None: ...
-
-
-class BooleanCostModel(CostModel[bool]):
-    cost_tp = bool
-
-    def primitive_cost(self, egraph: EGraph, value: Primitive) -> bool:
-        # Only allow extracting even integers
-        match value:
-            case i64(i) if i % 2 == 0:
-                return True
-        return False
-
-    def container_cost(self, egraph: EGraph, container: Expr, children_costs: list[bool]) -> bool:
-        # Only allow extracting Vecs of extractable values
-        match container:
-            case Vec():
-                return all(children_costs)
-        return False
-
-    def call_cost(self, egraph: EGraph, expr: Expr) -> bool:
-        # Only allow extracting calls to `my_f`
-        match expr:
-            case my_f():
-                return True
-        return False
-
-    def fold(self, callable: ExprCallable, children_costs: list[bool], head_cost: bool) -> bool:
-        # Only allow extracting calls where the head and all children are extractable
-        return head_cost and all(children_costs)
-
-assert EGraph().extract(i64(10), include_cost=True, cost_model=BooleanCostModel()) == (i64(10), True)
+    match expr:
+        case i64(i):
+            return i % 2 == 0
+    return False
+assert EGraph().extract(i64(10), include_cost=True, cost_model=is_even_cost_model) == (i64(10), True)
+
+assert EGraph().extract(i64(5), include_cost=True, cost_model=is_even_cost_model) == (i64(5), False)
 ```
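
To make the idea behind the new documentation text concrete, here is a minimal, library-free sketch of extraction driven by a plain cost function. It does not use egglog at all: `ToyEGraph`, `toy_extract`, `tree_size_cost`, and `mul_count_cost` are hypothetical names invented for illustration, and the toy cost functions take an operator name rather than the `(egraph, expr, children_costs)` arguments shown in the diff. It only illustrates two points from the text: a cost model can be a function of a node and its children's costs, and the cost can be any comparable type.

```python
from collections.abc import Callable
from typing import Any

# Toy stand-in for an e-graph (hypothetical, not the egglog data structure):
# each e-class id maps to its alternative nodes, and each node is
# (operator name, tuple of child e-class ids).
Node = tuple[str, tuple[int, ...]]
ToyEGraph = dict[int, list[Node]]


def tree_size_cost(op: str, children_costs: list[int]) -> int:
    """Integer cost: one per node plus the children's costs."""
    return 1 + sum(children_costs)


def mul_count_cost(op: str, children_costs: list[tuple[int, int]]) -> tuple[int, int]:
    """Tuple cost, compared lexicographically: (multiplications, total nodes)."""
    muls = sum(c[0] for c in children_costs) + int(op == "*")
    size = sum(c[1] for c in children_costs) + 1
    return (muls, size)


def toy_extract(
    graph: ToyEGraph, root: int, cost_model: Callable[[str, list[Any]], Any]
) -> tuple[Any, Any]:
    """Pick the cheapest expression for `root`, iterating until no cost improves."""
    best: dict[int, tuple[Any, Any]] = {}  # class id -> (cost, expression)
    changed = True
    while changed:
        changed = False
        for class_id, nodes in graph.items():
            for op, kids in nodes:
                if any(k not in best for k in kids):
                    continue  # some child has no known cost yet
                cost = cost_model(op, [best[k][0] for k in kids])
                # Only a strictly lower cost replaces the current choice.
                if class_id not in best or cost < best[class_id][0]:
                    best[class_id] = (cost, (op, *(best[k][1] for k in kids)))
                    changed = True
    cost, expr = best[root]
    return expr, cost


# E-class 2 holds two equivalent forms: x * 2 and x + x.
graph: ToyEGraph = {
    0: [("x", ())],
    1: [("2", ())],
    2: [("*", (0, 1)), ("+", (0, 0))],
}
print(toy_extract(graph, 2, tree_size_cost))
print(toy_extract(graph, 2, mul_count_cost))
```

With the integer cost both forms tie at three nodes, so the first alternative listed is kept; with the tuple cost the addition wins because it contains no multiplication. Swapping the cost function changes the extracted expression without touching the extraction loop, which is the design the commit title describes.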

pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -54,8 +54,7 @@ test = [
     "egglog[array]",
     "pytest-codspeed",
     "pytest-benchmark",
-    "pytest-xdist",
-    "pytest-mock",
+    "pytest-xdist"
 ]

 docs = [
