Skip to content

Commit 61d5395

Browse files
Add custom cost model and ability to get costs
1 parent db84e84 commit 61d5395

File tree

16 files changed

+1226
-43
lines changed

16 files changed

+1226
-43
lines changed

docs/reference/python-integration.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,50 @@ r = ruleset(
635635
)
636636
egraph.saturate(r)
637637
```
638+
639+
## Custom Cost Models
640+
641+
Custom cost models are also supported by subclassing `CostModel[T]` or `DefaultCostModel` and passing an instance as the `cost_model` keyword argument to `EGraph.extract`. The `CostModel` is parameterized by a cost type `T`, which in the `DefaultCostModel` is `int`. Costs must be comparable to each other so that the lowest cost can be chosen.
642+
643+
The `Expr`s passed to your cost model represent partial program trees. Any builtin values (containers or single primitives) will be fully evaluated, but any values that are instances of user-defined classes will be left as opaque "values",
644+
representing an e-class in the e-graph. The only thing you can do with values is to compare them to each other or
645+
use them in `EGraph.lookup_function_value` to look up the resulting value of a call with values in it.
646+
647+
For example, here is a cost model that uses boolean values to determine whether an expression is extractable or not:
648+
649+
```{code-cell} python
650+
651+
class MyExpr(Expr):
652+
def __init__(self) -> None: ...
653+
654+
655+
class BooleanCostModel(CostModel[bool]):
656+
cost_tp = bool
657+
658+
def primitive_cost(self, egraph: EGraph, value: Primitive) -> bool:
659+
# Only allow extracting even integers
660+
match value:
661+
case i64(i) if i % 2 == 0:
662+
return True
663+
return False
664+
665+
def container_cost(self, egraph: EGraph, container: Expr, children_costs: list[bool]) -> bool:
666+
# Only allow extracting Vecs of extractable values
667+
match container:
668+
case Vec():
669+
return all(children_costs)
670+
return False
671+
672+
def call_cost(self, egraph: EGraph, expr: Expr) -> bool:
673+
# Only allow extracting calls to `my_f`
674+
match expr:
675+
case my_f():
676+
return True
677+
return False
678+
679+
def fold(self, callable: ExprCallable, children_costs: list[bool], head_cost: bool) -> bool:
680+
# Only allow extracting calls where the head and all children are extractable
681+
return head_cost and all(children_costs)
682+
683+
assert EGraph().extract(i64(10), include_cost=True, cost_model=BooleanCostModel()) == (i64(10), True)
684+
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ test = [
5555
"pytest-codspeed",
5656
"pytest-benchmark",
5757
"pytest-xdist",
58+
"pytest-mock",
5859
]
5960

6061
docs = [

python/egglog/bindings.pyi

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from collections.abc import Callable
12
from datetime import timedelta
3+
from fractions import Fraction
24
from pathlib import Path
3-
from typing import TypeAlias
5+
from typing import Generic, Protocol, TypeAlias, TypeVar
46

57
from typing_extensions import final
68

@@ -14,6 +16,7 @@ __all__ = [
1416
"Change",
1517
"Check",
1618
"Constructor",
19+
"CostModel",
1720
"Datatype",
1821
"Datatypes",
1922
"DefaultPrintFunctionMode",
@@ -26,6 +29,7 @@ __all__ = [
2629
"Extract",
2730
"ExtractBest",
2831
"ExtractVariants",
32+
"Extractor",
2933
"Fact",
3034
"Fail",
3135
"Float",
@@ -83,6 +87,7 @@ __all__ = [
8387
"UserDefined",
8488
"UserDefinedCommandOutput",
8589
"UserDefinedOutput",
90+
"Value",
8691
"Var",
8792
"Variant",
8893
]
@@ -128,6 +133,31 @@ class EGraph:
128133
max_calls_per_function: int | None = None,
129134
include_temporary_functions: bool = False,
130135
) -> SerializedEGraph: ...
136+
def lookup_function(self, name: str, key: list[Value]) -> Value | None: ...
137+
def eval_expr(self, expr: _Expr) -> tuple[str, Value]: ...
138+
def value_to_i64(self, v: Value) -> int: ...
139+
def value_to_f64(self, v: Value) -> float: ...
140+
def value_to_string(self, v: Value) -> str: ...
141+
def value_to_bool(self, v: Value) -> bool: ...
142+
def value_to_rational(self, v: Value) -> Fraction: ...
143+
def value_to_bigint(self, v: Value) -> int: ...
144+
def value_to_bigrat(self, v: Value) -> Fraction: ...
145+
def value_to_pyobject(self, py_object_sort: PyObjectSort, v: Value) -> object: ...
146+
def value_to_map(self, v: Value) -> dict[Value, Value]: ...
147+
def value_to_multiset(self, v: Value) -> list[Value]: ...
148+
def value_to_vec(self, v: Value) -> list[Value]: ...
149+
def value_to_function(self, v: Value) -> tuple[str, list[Value]]: ...
150+
def value_to_set(self, v: Value) -> set[Value]: ...
151+
# def dynamic_cost_model_enode_cost(self, func: str, args: list[Value]) -> int: ...
152+
153+
@final
154+
class Value:
155+
def __hash__(self) -> int: ...
156+
def __eq__(self, value: object) -> bool: ...
157+
def __lt__(self, other: object) -> bool: ...
158+
def __le__(self, other: object) -> bool: ...
159+
def __gt__(self, other: object) -> bool: ...
160+
def __ge__(self, other: object) -> bool: ...
131161

132162
@final
133163
class EggSmolError(Exception):
@@ -732,3 +762,33 @@ class TermDag:
732762
def expr_to_term(self, expr: _Expr) -> _Term: ...
733763
def term_to_expr(self, term: _Term, span: _Span) -> _Expr: ...
734764
def to_string(self, term: _Term) -> str: ...
765+
766+
##
767+
# Extraction
768+
##
769+
class _Cost(Protocol):
770+
def __lt__(self, other: _Cost) -> bool: ...
771+
def __le__(self, other: _Cost) -> bool: ...
772+
def __gt__(self, other: _Cost) -> bool: ...
773+
def __ge__(self, other: _Cost) -> bool: ...
774+
def __add__(self, other: _Cost) -> _Cost: ...
775+
776+
_COST = TypeVar("_COST", bound=_Cost)
777+
778+
@final
779+
class CostModel(Generic[_COST]):
780+
def __init__(
781+
self,
782+
fold: Callable[[str, _COST, list[_COST]], _COST] | None,
783+
enode_cost: Callable[[str, list[Value]], _COST] | None,
784+
container_cost: Callable[[str, Value, list[_COST]], _COST] | None,
785+
base_value_cost: Callable[[str, Value], _COST] | None,
786+
) -> None: ...
787+
788+
@final
789+
class Extractor(Generic[_COST]):
790+
def __init__(self, rootsorts: list[str] | None, egraph: EGraph, cost_model: CostModel[_COST]) -> None: ...
791+
def extract_best(self, egraph: EGraph, termdag: TermDag, value: Value, sort: str) -> tuple[_COST, _Term]: ...
792+
def extract_variants(
793+
self, egraph: EGraph, termdag: TermDag, value: Value, nvariants: int, sort: str
794+
) -> list[tuple[_COST, _Term]]: ...

python/egglog/builtins.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,12 @@
3333
"BigRatLike",
3434
"Bool",
3535
"BoolLike",
36+
"Container",
3637
"ExprValueError",
3738
"Map",
3839
"MapLike",
3940
"MultiSet",
41+
"Primitive",
4042
"PyObject",
4143
"Rational",
4244
"Set",
@@ -57,6 +59,9 @@
5759
"py_exec",
5860
]
5961

62+
Container: TypeAlias = "Map | Set | MultiSet | Vec | UnstableFn"
63+
Primitive: TypeAlias = "String | Bool | i64 | f64 | Rational | BigInt | BigRat | PyObject | Unit"
64+
6065

6166
@dataclass
6267
class ExprValueError(AttributeError):

python/egglog/declarations.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from typing_extensions import Self, assert_never
1616

17+
from .bindings import Value
18+
1719
if TYPE_CHECKING:
1820
from collections.abc import Callable, Iterable, Mapping
1921

@@ -49,6 +51,7 @@
4951
"FunctionDecl",
5052
"FunctionRef",
5153
"FunctionSignature",
54+
"GetCostDecl",
5255
"HasDeclerations",
5356
"InitRef",
5457
"JustTypeRef",
@@ -82,6 +85,7 @@
8285
"UnboundVarDecl",
8386
"UnionDecl",
8487
"UnnamedFunctionRef",
88+
"ValueDecl",
8589
"collect_unbound_vars",
8690
"replace_typed_expr",
8791
"upcast_declerations",
@@ -696,7 +700,20 @@ class PartialCallDecl:
696700
call: CallDecl
697701

698702

699-
ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
703+
@dataclass(frozen=True)
704+
class GetCostDecl:
705+
callable: CallableRef
706+
args: tuple[TypedExprDecl, ...]
707+
708+
709+
@dataclass(frozen=True)
710+
class ValueDecl:
711+
value: Value
712+
713+
714+
ExprDecl: TypeAlias = (
715+
UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl
716+
)
700717

701718

702719
@dataclass(frozen=True)

python/egglog/deconstruct.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing_extensions import TypeVarTuple, Unpack
1212

1313
from .declarations import *
14-
from .egraph import BaseExpr
14+
from .egraph import BaseExpr, Expr
1515
from .runtime import *
1616
from .thunk import *
1717

@@ -49,7 +49,11 @@ def get_literal_value(x: PyObject) -> object: ...
4949
def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
5050

5151

52-
def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object:
52+
@overload
53+
def get_literal_value(x: Expr) -> None: ...
54+
55+
56+
def get_literal_value(x: object) -> object:
5357
"""
5458
Returns the literal value of an expression if it is a literal.
5559
If it is not a literal, returns None.
@@ -95,12 +99,9 @@ def get_var_name(x: BaseExpr) -> str | None:
9599
return None
96100

97101

98-
def get_callable_fn(x: T) -> Callable[..., T] | None:
102+
def get_callable_fn(x: T) -> Callable[..., T] | T | None:
99103
"""
100-
Gets the function of an expression if it is a call expression.
101-
If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None.
102-
For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()`
103-
to return the Python value.
104+
Gets the function of an expression, or if it's a constant or classvar, return that.
104105
"""
105106
if not isinstance(x, RuntimeExpr):
106107
raise TypeError(f"Expected Expression, got {type(x).__name__}")
@@ -159,6 +160,7 @@ def _deconstruct_call_decl(
159160
"""
160161
args = call.args
161162
arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
163+
# TODO: handle values? Like constants
162164
if isinstance(call.callable, InitRef):
163165
return RuntimeClass(
164166
decls_thunk,

0 commit comments

Comments
 (0)