diff --git a/.gitignore b/.gitignore index 94e71e19..e2048df2 100644 --- a/.gitignore +++ b/.gitignore @@ -84,3 +84,4 @@ Source.* inlined visualizer.tgz package +.mypy_cache/ diff --git a/docs/changelog.md b/docs/changelog.md index 44417898..8d59cd3b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343) - Add `egraph.function_values(fn)` to export all function values like `print-function` [#340](https://github.com/egraphs-good/egglog-python/pull/340) - Add `egraph.stats()` method to print overall stats [#339](https://github.com/egraphs-good/egglog-python/pull/339) - Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338) diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index fcbb034f..5162233a 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -268,6 +268,23 @@ except BaseException as e: print(e) ``` +### Set Cost + +You can also set the cost of individual values, like the egglog experimental feature, to override the default cost from constructing a function: + +```{code-cell} python +# egg: (set-cost (fib 0) 1) +egraph.register(set_cost(fib(0), 1)) +``` + +This will be taken into account when extracting. Any value that can be converted to an `i64` is supported as a cost, +so dynamic costs can be created in rules. + +It does this by creating a new table for each function you set the cost for that maps the arguments to an i64. 
+ +_Note: Unlike in egglog, where you have to declare which functions support custom costs, in Python all functions +are automatically registered to create a custom cost table when they are constructed_ + ## Defining Rules To define rules in Python, we create a rule with the `rule(*facts).then(*actions) (rule ...)` command in egglog. diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 81049126..1e2f54cd 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -69,6 +69,7 @@ "SaturateDecl", "ScheduleDecl", "SequenceDecl", + "SetCostDecl", "SetDecl", "SpecialFunctions", "TypeOrVarRef", @@ -854,7 +855,14 @@ class PanicDecl: msg: str -ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl +@dataclass(frozen=True) +class SetCostDecl: + tp: JustTypeRef + expr: CallDecl + cost: ExprDecl + + +ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl ## diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index ac2a64b2..00f6eb59 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -40,7 +40,7 @@ from .version_compat import * if TYPE_CHECKING: - from .builtins import String, Unit + from .builtins import String, Unit, i64Like __all__ = [ @@ -84,6 +84,7 @@ "run", "seq", "set_", + "set_cost", "subsume", "union", "unstable_combine_rulesets", @@ -985,8 +986,14 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]: def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput: self._add_decls(expr) expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__) + # If we have defined any cost tables use the custom extraction + args = (expr, bindings.Lit(span(2), bindings.Int(n))) + if self._state.cost_callables: + cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args)) + else: + cmd = bindings.Extract(span(2), *args) try: - return 
self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))[0] + return self._egraph.run_program(cmd)[0] except BaseException as e: raise add_note("Extracting: " + str(expr), e) # noqa: B904 @@ -1460,10 +1467,13 @@ def __bool__(self) -> bool: """ Returns True if the two sides of an equality are structurally equal. """ - if not isinstance(self.fact, EqDecl): - msg = "Can only check equality facts" - raise TypeError(msg) - return self.fact.left == self.fact.right + match self.fact: + case EqDecl(_, left, right): + return left == right + case ExprFactDecl(TypedExprDecl(_, CallDecl(FunctionRef("!="), (left_tp, right_tp)))): + return left_tp != right_tp + msg = f"Can only check equality for == or != not {self}" + raise ValueError(msg) @dataclass @@ -1511,6 +1521,26 @@ def panic(message: str) -> Action: return Action(Declarations(), PanicDecl(message)) +def set_cost(expr: BaseExpr, cost: i64Like) -> Action: + """ + Create an action that sets the extraction cost of one particular call. + + `expr` must be a call (not a literal or variable); `cost` is converted to an `i64` expression. + Raises a `TypeError` if `expr` is not a call. + """ + from .builtins import i64 # noqa: PLC0415 + + expr_runtime = to_runtime_expr(expr) + typed_expr_decl = expr_runtime.__egg_typed_expr__ + expr_decl = typed_expr_decl.expr + # Raise instead of asserting so the check still runs under `python -O` + if not isinstance(expr_decl, CallDecl): + msg = "Can only set cost of calls, not literals or vars" + raise TypeError(msg) + cost_decl = to_runtime_expr(convert(cost, i64)).__egg_typed_expr__.expr + return Action(expr_runtime.__egg_decls__, SetCostDecl(typed_expr_decl.tp, expr_decl, cost_decl)) + + def let(name: str, expr: BaseExpr) -> Action: """Create a let binding.""" runtime_expr = to_runtime_expr(expr) diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index a4acfc0b..52b0e291 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -6,7 +6,7 @@ import re from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Literal, overload from typing_extensions import assert_never @@ -71,6 +71,9 @@ 
class EGraphState: # Cache of egg expressions for converting to egg expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) + # Callables which have cost tables associated with them + cost_callables: set[CallableRef] = field(default_factory=set) + def copy(self) -> EGraphState: """ Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping. @@ -83,6 +86,7 @@ def copy(self) -> EGraphState: callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(), type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(), expr_to_egg_cache=self.expr_to_egg_cache.copy(), + cost_callables=self.cost_callables.copy(), ) def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: @@ -212,9 +216,32 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs)) case PanicDecl(name): return bindings.Panic(span(), name) + case SetCostDecl(tp, expr, cost): + self.type_ref_to_egg(tp) + cost_table = self.create_cost_table(expr.callable) + args_egg = [self.typed_expr_to_egg(x, False) for x in expr.args] + return bindings.Set(span(), cost_table, args_egg, self._expr_to_egg(cost)) case _: assert_never(action) + def create_cost_table(self, ref: CallableRef) -> str: + """ + Creates the egg cost table if needed and gets the name of the table. 
+ """ + name = self.cost_table_name(ref) + if ref not in self.cost_callables: + signature = self.__egg_decls__.get_callable_decl(ref).signature + assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions" + signature = replace(signature, return_type=TypeRefWithVars("i64")) + self.egraph.run_program( + bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None) + ) + self.cost_callables.add(ref) # record only after the table was created, so a failure is not remembered + return name + + def cost_table_name(self, ref: CallableRef) -> str: + return f"cost_table_{self.callable_ref_to_egg(ref)[0]}" + def fact_to_egg(self, fact: FactDecl) -> bindings._Fact: match fact: case EqDecl(tp, left, right): @@ -350,11 +377,16 @@ def op_mapping(self) -> dict[str, str]: """ Create a mapping of egglog function name to Python function name, for use in the serialized format for better visualization. + + Includes cost tables """ return { k: pretty_callable_ref(self.__egg_decls__, next(iter(v))) for k, v in self.egg_fn_to_callable_refs.items() if len(v) == 1 + } | { + self.cost_table_name(ref): f"cost({pretty_callable_ref(self.__egg_decls__, ref, include_all_args=True)})" + for ref in self.cost_callables } def possible_egglog_functions(self, names: list[str]) -> Iterable[str]: diff --git a/python/egglog/examples/jointree.py b/python/egglog/examples/jointree.py new file mode 100644 index 00000000..e596f95e --- /dev/null +++ b/python/egglog/examples/jointree.py @@ -0,0 +1,67 @@ +# mypy: disable-error-code="empty-body" + +""" +Join Tree (custom costs) +======================== + +Example of using custom cost functions for jointree. + +From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F +""" + +from __future__ import annotations + +from egglog import * + + +class JoinTree(Expr): + def __init__(self, name: StringLike) -> None: ... + + def join(self, other: JoinTree) -> JoinTree: ... 
+ + @method(merge=lambda old, new: old.min(new)) # type:ignore[prop-decorator] + @property + def size(self) -> i64: ... + + +ra = JoinTree("a") +rb = JoinTree("b") +rc = JoinTree("c") +rd = JoinTree("d") +re = JoinTree("e") +rf = JoinTree("f") + +query = ra.join(rb).join(rc).join(rd).join(re).join(rf) + +egraph = EGraph() +egraph.register( + set_(ra.size).to(50), + set_(rb.size).to(200), + set_(rc.size).to(10), + set_(rd.size).to(123), + set_(re.size).to(10000), + set_(rf.size).to(1), +) + + +@egraph.register +def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64): + # cost of relation is its size minus 1, since the string arg will have a cost of 1 as well + yield rule(JoinTree(s).size == asize).then(set_cost(JoinTree(s), asize - 1)) + # cost/size of join is product of sizes + yield rule(a.join(b), a.size == asize, b.size == bsize).then( + set_(a.join(b).size).to(asize * bsize), set_cost(a.join(b), asize * bsize) + ) + # commutativity + yield rewrite(a.join(b)).to(b.join(a)) + # associativity + yield rewrite(a.join(b).join(c)).to(a.join(b.join(c))) + + +egraph.register(query) +egraph.run(1000) +print(egraph.extract(query)) +print(egraph.extract(query.size)) + + +egraph diff --git a/python/egglog/pretty.py b/python/egglog/pretty.py index 1ebbc586..8c8c7130 100644 --- a/python/egglog/pretty.py +++ b/python/egglog/pretty.py @@ -98,6 +98,7 @@ def pretty_callable_ref( ref: CallableRef, first_arg: ExprDecl | None = None, bound_tp_params: tuple[JustTypeRef, ...] | None = None, + include_all_args: bool = False, ) -> str: """ Pretty print a callable reference, using a dummy value for @@ -115,6 +116,13 @@ def pretty_callable_ref( # Either returns a function or a function with args. 
If args are provided, they would just be called, # on the function, so return them, because they are dummies if isinstance(res, tuple): + # If we want to include all args as ARG_STR, then we need to figure out how many to use + # used for set_cost so that `cost(E(...))` will show up as a call + if include_all_args: + signature = decls.get_callable_decl(ref).signature + assert isinstance(signature, FunctionSignature) + correct_args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * len(signature.arg_types) + return f"{res[0]}({', '.join(context(a, parens=False, unwrap_lit=True) for a in correct_args)})" return res[0] return res @@ -190,6 +198,9 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 pass case DefaultRewriteDecl(): pass + case SetCostDecl(_, e, c): + self(e) + self(c) case _: assert_never(decl) @@ -285,6 +296,8 @@ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na return f"{change}({self(expr)})", "action" case PanicDecl(s): return f"panic({s!r})", "action" + case SetCostDecl(_, expr, cost): + return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action" case EqDecl(_, left, right): return f"eq({self(left)}).to({self(right)})", "fact" case RulesetDecl(rules): diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 2771feaa..28cf61c9 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -1063,3 +1063,33 @@ def f(x: i64Like) -> i64: ... egraph.register(set_(f(i64(1))).to(i64(2))) values = egraph.function_values(f) assert values == {f(i64(1)): i64(2)} + + +def test_dynamic_cost(): + """ + https://github.com/egraphs-good/egglog-experimental/blob/6d07a34ac76deec751f86f70d9b9358cd3e236ca/tests/integration_test.rs#L5-L35 + """ + + class E(Expr): + def __init__(self, x: i64Like) -> None: ... + def __add__(self, other: E) -> E: ... + @method(cost=200) + def __sub__(self, other: E) -> E: ... 
+ + egraph = EGraph() + egraph.register( + union(E(2)).with_(E(1) + E(1)), + set_cost(E(2), 1000), + set_cost(E(1), 100), + ) + assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203) + with egraph: + egraph.register(set_cost(E(1) + E(1), 800)) + assert egraph.extract(E(2), include_cost=True) == (E(2), 1001) + with egraph: + egraph.register(set_cost(E(1) + E(1), 798)) + assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 1000) + egraph.register(union(E(2)).with_(E(5) - E(3))) + assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203) + egraph.register(set_cost(E(5) - E(3), 198)) + assert egraph.extract(E(2), include_cost=True) == (E(5) - E(3), 202)