Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,4 @@ Source.*
inlined
visualizer.tgz
package
.mypy_cache/
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _This project uses semantic versioning_

## UNRELEASED

- Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343)
- Add `egraph.function_values(fn)` to export all function values like `print-function` [#340](https://github.com/egraphs-good/egglog-python/pull/340)
- Add `egraph.stats()` method to print overall stats [#339](https://github.com/egraphs-good/egglog-python/pull/339)
- Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338)
Expand Down
17 changes: 17 additions & 0 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,23 @@ except BaseException as e:
print(e)
```

### Set Cost

You can also set the cost of individual values, like the egglog experimental feature, to override the default cost from constructing a function:

```{code-cell} python
# egg: (set-cost (fib 0) 1)
egraph.register(set_cost(fib(0), 1))
```

This will be taken into account when extracting. Any value that can be converted to an `i64` is supported as a cost,
so dynamic costs can be created in rules.

It does this by creating a new table for each function you set the cost for that maps the arguments to an i64.

_Note: Unlike in egglog, where you have to declare which functions support custom costs, in Python all functions
are automatically regeistered to create a custom cost table when they are constructed_

## Defining Rules

To define rules in Python, we create a rule with the `rule(*facts).then(*actions) (rule ...)` command in egglog.
Expand Down
10 changes: 9 additions & 1 deletion python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"SaturateDecl",
"ScheduleDecl",
"SequenceDecl",
"SetCostDecl",
"SetDecl",
"SpecialFunctions",
"TypeOrVarRef",
Expand Down Expand Up @@ -854,7 +855,14 @@ class PanicDecl:
msg: str


ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
@dataclass(frozen=True)
class SetCostDecl:
tp: JustTypeRef
expr: CallDecl
cost: ExprDecl


ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl


##
Expand Down
34 changes: 28 additions & 6 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .version_compat import *

if TYPE_CHECKING:
from .builtins import String, Unit
from .builtins import String, Unit, i64Like


__all__ = [
Expand Down Expand Up @@ -84,6 +84,7 @@
"run",
"seq",
"set_",
"set_cost",
"subsume",
"union",
"unstable_combine_rulesets",
Expand Down Expand Up @@ -985,8 +986,14 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
self._add_decls(expr)
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
# If we have defined any cost tables use the custom extraction
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
if self._state.cost_callables:
cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
else:
cmd = bindings.Extract(span(2), *args)
try:
return self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))[0]
return self._egraph.run_program(cmd)[0]
except BaseException as e:
raise add_note("Extracting: " + str(expr), e) # noqa: B904

Expand Down Expand Up @@ -1460,10 +1467,13 @@ def __bool__(self) -> bool:
"""
Returns True if the two sides of an equality are structurally equal.
"""
if not isinstance(self.fact, EqDecl):
msg = "Can only check equality facts"
raise TypeError(msg)
return self.fact.left == self.fact.right
match self.fact:
case EqDecl(_, left, right):
return left == right
case ExprFactDecl(TypedExprDecl(_, CallDecl(FunctionRef("!="), (left_tp, right_tp)))):
return left_tp != right_tp
msg = f"Can only check equality for == or != not {self}"
raise ValueError(msg)


@dataclass
Expand Down Expand Up @@ -1511,6 +1521,18 @@ def panic(message: str) -> Action:
return Action(Declarations(), PanicDecl(message))


def set_cost(expr: BaseExpr, cost: i64Like) -> Action:
"""Set the cost of the given expression."""
from .builtins import i64 # noqa: PLC0415

expr_runtime = to_runtime_expr(expr)
typed_expr_decl = expr_runtime.__egg_typed_expr__
expr_decl = typed_expr_decl.expr
assert isinstance(expr_decl, CallDecl), "Can only set cost of calls, not literals or vars"
cost_decl = to_runtime_expr(convert(cost, i64)).__egg_typed_expr__.expr
return Action(expr_runtime.__egg_decls__, SetCostDecl(typed_expr_decl.tp, expr_decl, cost_decl))


def let(name: str, expr: BaseExpr) -> Action:
"""Create a let binding."""
runtime_expr = to_runtime_expr(expr)
Expand Down
34 changes: 33 additions & 1 deletion python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import re
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Literal, overload

from typing_extensions import assert_never
Expand Down Expand Up @@ -71,6 +71,9 @@ class EGraphState:
# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

# Callables which have cost tables associated with them
cost_callables: set[CallableRef] = field(default_factory=set)

def copy(self) -> EGraphState:
"""
Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping.
Expand All @@ -83,6 +86,7 @@ def copy(self) -> EGraphState:
callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
expr_to_egg_cache=self.expr_to_egg_cache.copy(),
cost_callables=self.cost_callables.copy(),
)

def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
Expand Down Expand Up @@ -212,9 +216,32 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin
return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs))
case PanicDecl(name):
return bindings.Panic(span(), name)
case SetCostDecl(tp, expr, cost):
self.type_ref_to_egg(tp)
cost_table = self.create_cost_table(expr.callable)
args_egg = [self.typed_expr_to_egg(x, False) for x in expr.args]
return bindings.Set(span(), cost_table, args_egg, self._expr_to_egg(cost))
case _:
assert_never(action)

def create_cost_table(self, ref: CallableRef) -> str:
"""
Creates the egg cost table if needed and gets the name of the table.
"""
name = self.cost_table_name(ref)
if ref not in self.cost_callables:
self.cost_callables.add(ref)
signature = self.__egg_decls__.get_callable_decl(ref).signature
assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions"
signature = replace(signature, return_type=TypeRefWithVars("i64"))
self.egraph.run_program(
bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None)
)
return name

def cost_table_name(self, ref: CallableRef) -> str:
return f"cost_table_{self.callable_ref_to_egg(ref)[0]}"

def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
match fact:
case EqDecl(tp, left, right):
Expand Down Expand Up @@ -350,11 +377,16 @@ def op_mapping(self) -> dict[str, str]:
"""
Create a mapping of egglog function name to Python function name, for use in the serialized format
for better visualization.

Includes cost tables
"""
return {
k: pretty_callable_ref(self.__egg_decls__, next(iter(v)))
for k, v in self.egg_fn_to_callable_refs.items()
if len(v) == 1
} | {
self.cost_table_name(ref): f"cost({pretty_callable_ref(self.__egg_decls__, ref, include_all_args=True)})"
for ref in self.cost_callables
}

def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
Expand Down
67 changes: 67 additions & 0 deletions python/egglog/examples/jointree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# mypy: disable-error-code="empty-body"

"""
Join Tree (custom costs)
========================

Example of using custom cost functions for jointree.

From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F
"""

from __future__ import annotations

from egglog import *


class JoinTree(Expr):
def __init__(self, name: StringLike) -> None: ...

def join(self, other: JoinTree) -> JoinTree: ...

@method(merge=lambda old, new: old.min(new))
@property
def size(self) -> i64: ...


ra = JoinTree("a")
rb = JoinTree("b")
rc = JoinTree("c")
rd = JoinTree("d")
re = JoinTree("e")
rf = JoinTree("f")

query = ra.join(rb).join(rc).join(rd).join(re).join(rf)

egraph = EGraph()
egraph.register(
set_(ra.size).to(50),
set_(rb.size).to(200),
set_(rc.size).to(10),
set_(rd.size).to(123),
set_(re.size).to(10000),
set_(rf.size).to(1),
)


@egraph.register
def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64):
# cost of relation is its size minus 1, since the string arg will have a cost of 1 as well
yield rule(JoinTree(s).size == asize).then(set_cost(JoinTree(s), asize - 1))
# cost/size of join is product of sizes
yield rule(a.join(b), a.size == asize, b.size == bsize).then(
set_(a.join(b).size).to(asize * bsize), set_cost(a.join(b), asize * bsize)
)
# associativity
yield rewrite(a.join(b)).to(b.join(a))
# commutativity
yield rewrite(a.join(b).join(c)).to(a.join(b.join(c)))


egraph.register(query)
egraph.run(1000)
print(egraph.extract(query))
print(egraph.extract(query.size))


egraph
13 changes: 13 additions & 0 deletions python/egglog/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def pretty_callable_ref(
ref: CallableRef,
first_arg: ExprDecl | None = None,
bound_tp_params: tuple[JustTypeRef, ...] | None = None,
include_all_args: bool = False,
) -> str:
"""
Pretty print a callable reference, using a dummy value for
Expand All @@ -115,6 +116,13 @@ def pretty_callable_ref(
# Either returns a function or a function with args. If args are provided, they would just be called,
# on the function, so return them, because they are dummies
if isinstance(res, tuple):
# If we want to include all args as ARG_STR, then we need to figure out how many to use
# used for set_cost so that `cost(E(...))` will show up as a call
if include_all_args:
signature = decls.get_callable_decl(ref).signature
assert isinstance(signature, FunctionSignature)
args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * len(signature.arg_types)
return f"{res[0]}({', '.join(context(a, parens=False, unwrap_lit=True) for a in args)})"
return res[0]
return res

Expand Down Expand Up @@ -190,6 +198,9 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90
pass
case DefaultRewriteDecl():
pass
case SetCostDecl(_, e, c):
self(e)
self(c)
case _:
assert_never(decl)

Expand Down Expand Up @@ -285,6 +296,8 @@ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na
return f"{change}({self(expr)})", "action"
case PanicDecl(s):
return f"panic({s!r})", "action"
case SetCostDecl(_, expr, cost):
return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action"
case EqDecl(_, left, right):
return f"eq({self(left)}).to({self(right)})", "fact"
case RulesetDecl(rules):
Expand Down
30 changes: 30 additions & 0 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,3 +1063,33 @@ def f(x: i64Like) -> i64: ...
egraph.register(set_(f(i64(1))).to(i64(2)))
values = egraph.function_values(f)
assert values == {f(i64(1)): i64(2)}


def test_dynamic_cost():
"""
https://github.com/egraphs-good/egglog-experimental/blob/6d07a34ac76deec751f86f70d9b9358cd3e236ca/tests/integration_test.rs#L5-L35
"""

class E(Expr):
def __init__(self, x: i64Like) -> None: ...
def __add__(self, other: E) -> E: ...
@method(cost=200)
def __sub__(self, other: E) -> E: ...

egraph = EGraph()
egraph.register(
union(E(2)).with_(E(1) + E(1)),
set_cost(E(2), 1000),
set_cost(E(1), 100),
)
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203)
with egraph:
egraph.register(set_cost(E(1) + E(1), 800))
assert egraph.extract(E(2), include_cost=True) == (E(2), 1001)
with egraph:
egraph.register(set_cost(E(1) + E(1), 798))
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 1000)
egraph.register(union(E(2)).with_(E(5) - E(3)))
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203)
egraph.register(set_cost(E(5) - E(3), 198))
assert egraph.extract(E(2), include_cost=True) == (E(5) - E(3), 202)