Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,4 @@ Source.*
inlined
visualizer.tgz
package
.mypy_cache/
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _This project uses semantic versioning_

## UNRELEASED

- Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343)
- Add `egraph.function_values(fn)` to export all function values like `print-function` [#340](https://github.com/egraphs-good/egglog-python/pull/340)
- Add `egraph.stats()` method to print overall stats [#339](https://github.com/egraphs-good/egglog-python/pull/339)
- Add `all_function_sizes` and `function_size` EGraph methods [#338](https://github.com/egraphs-good/egglog-python/pull/338)
Expand Down
17 changes: 17 additions & 0 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,23 @@ except BaseException as e:
print(e)
```

### Set Cost

You can also set the cost of individual values, like the egglog experimental feature, to override the default cost from constructing a function:

```{code-cell} python
# egg: (set-cost (fib 0) 1)
egraph.register(set_cost(fib(0), 1))
```

This will be taken into account when extracting. Any value that can be converted to an `i64` is supported as a cost,
so dynamic costs can be created in rules.

It does this by creating a new table for each function you set the cost for that maps the arguments to an i64.

_Note: Unlike in egglog, where you have to declare which functions support custom costs, in Python all functions
are automatically registered to create a custom cost table when they are constructed_

## Defining Rules

To define rules in Python, we create a rule with the `rule(*facts).then(*actions) (rule ...)` command in egglog.
Expand Down
10 changes: 9 additions & 1 deletion python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
"SaturateDecl",
"ScheduleDecl",
"SequenceDecl",
"SetCostDecl",
"SetDecl",
"SpecialFunctions",
"TypeOrVarRef",
Expand Down Expand Up @@ -854,7 +855,14 @@ class PanicDecl:
msg: str


ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
@dataclass(frozen=True)
class SetCostDecl:
tp: JustTypeRef
expr: CallDecl
cost: ExprDecl


ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl


##
Expand Down
34 changes: 28 additions & 6 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from .version_compat import *

if TYPE_CHECKING:
from .builtins import String, Unit
from .builtins import String, Unit, i64Like


__all__ = [
Expand Down Expand Up @@ -84,6 +84,7 @@
"run",
"seq",
"set_",
"set_cost",
"subsume",
"union",
"unstable_combine_rulesets",
Expand Down Expand Up @@ -985,8 +986,14 @@ def extract_multiple(self, expr: BASE_EXPR, n: int) -> list[BASE_EXPR]:
def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._CommandOutput:
self._add_decls(expr)
expr = self._state.typed_expr_to_egg(expr.__egg_typed_expr__)
# If we have defined any cost tables use the custom extraction
args = (expr, bindings.Lit(span(2), bindings.Int(n)))
if self._state.cost_callables:
cmd: bindings._Command = bindings.UserDefined(span(2), "extract", list(args))
else:
cmd = bindings.Extract(span(2), *args)
try:
return self._egraph.run_program(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))[0]
return self._egraph.run_program(cmd)[0]
except BaseException as e:
raise add_note("Extracting: " + str(expr), e) # noqa: B904

Expand Down Expand Up @@ -1460,10 +1467,13 @@ def __bool__(self) -> bool:
"""
Returns True if the two sides of an equality are structurally equal.
"""
if not isinstance(self.fact, EqDecl):
msg = "Can only check equality facts"
raise TypeError(msg)
return self.fact.left == self.fact.right
match self.fact:
case EqDecl(_, left, right):
return left == right
case ExprFactDecl(TypedExprDecl(_, CallDecl(FunctionRef("!="), (left_tp, right_tp)))):
return left_tp != right_tp
msg = f"Can only check equality for == or != not {self}"
raise ValueError(msg)


@dataclass
Expand Down Expand Up @@ -1511,6 +1521,18 @@ def panic(message: str) -> Action:
return Action(Declarations(), PanicDecl(message))


def set_cost(expr: BaseExpr, cost: i64Like) -> Action:
"""Set the cost of the given expression."""
from .builtins import i64 # noqa: PLC0415

expr_runtime = to_runtime_expr(expr)
typed_expr_decl = expr_runtime.__egg_typed_expr__
expr_decl = typed_expr_decl.expr
assert isinstance(expr_decl, CallDecl), "Can only set cost of calls, not literals or vars"
cost_decl = to_runtime_expr(convert(cost, i64)).__egg_typed_expr__.expr
return Action(expr_runtime.__egg_decls__, SetCostDecl(typed_expr_decl.tp, expr_decl, cost_decl))


def let(name: str, expr: BaseExpr) -> Action:
"""Create a let binding."""
runtime_expr = to_runtime_expr(expr)
Expand Down
34 changes: 33 additions & 1 deletion python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import re
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Literal, overload

from typing_extensions import assert_never
Expand Down Expand Up @@ -71,6 +71,9 @@ class EGraphState:
# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

# Callables which have cost tables associated with them
cost_callables: set[CallableRef] = field(default_factory=set)

def copy(self) -> EGraphState:
"""
Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping.
Expand All @@ -83,6 +86,7 @@ def copy(self) -> EGraphState:
callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
expr_to_egg_cache=self.expr_to_egg_cache.copy(),
cost_callables=self.cost_callables.copy(),
)

def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
Expand Down Expand Up @@ -212,9 +216,32 @@ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindin
return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs))
case PanicDecl(name):
return bindings.Panic(span(), name)
case SetCostDecl(tp, expr, cost):
self.type_ref_to_egg(tp)
cost_table = self.create_cost_table(expr.callable)
args_egg = [self.typed_expr_to_egg(x, False) for x in expr.args]
return bindings.Set(span(), cost_table, args_egg, self._expr_to_egg(cost))
case _:
assert_never(action)

def create_cost_table(self, ref: CallableRef) -> str:
"""
Creates the egg cost table if needed and gets the name of the table.
"""
name = self.cost_table_name(ref)
if ref not in self.cost_callables:
self.cost_callables.add(ref)
signature = self.__egg_decls__.get_callable_decl(ref).signature
assert isinstance(signature, FunctionSignature), "Can only add cost tables for functions"
signature = replace(signature, return_type=TypeRefWithVars("i64"))
self.egraph.run_program(
bindings.FunctionCommand(span(), name, self._signature_to_egg_schema(signature), None)
)
return name

def cost_table_name(self, ref: CallableRef) -> str:
return f"cost_table_{self.callable_ref_to_egg(ref)[0]}"

def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
match fact:
case EqDecl(tp, left, right):
Expand Down Expand Up @@ -350,11 +377,16 @@ def op_mapping(self) -> dict[str, str]:
"""
Create a mapping of egglog function name to Python function name, for use in the serialized format
for better visualization.

Includes cost tables
"""
return {
k: pretty_callable_ref(self.__egg_decls__, next(iter(v)))
for k, v in self.egg_fn_to_callable_refs.items()
if len(v) == 1
} | {
self.cost_table_name(ref): f"cost({pretty_callable_ref(self.__egg_decls__, ref, include_all_args=True)})"
for ref in self.cost_callables
}

def possible_egglog_functions(self, names: list[str]) -> Iterable[str]:
Expand Down
67 changes: 67 additions & 0 deletions python/egglog/examples/jointree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# mypy: disable-error-code="empty-body"

"""
Join Tree (custom costs)
========================

Example of using custom cost functions for jointree.

From https://egraphs.zulipchat.com/#narrow/stream/328972-general/topic/How.20can.20I.20find.20the.20tree.20associated.20with.20an.20extraction.3F
"""

from __future__ import annotations

from egglog import *


class JoinTree(Expr):
def __init__(self, name: StringLike) -> None: ...

def join(self, other: JoinTree) -> JoinTree: ...

@method(merge=lambda old, new: old.min(new)) # type:ignore[prop-decorator]
@property
def size(self) -> i64: ...


ra = JoinTree("a")
rb = JoinTree("b")
rc = JoinTree("c")
rd = JoinTree("d")
re = JoinTree("e")
rf = JoinTree("f")

query = ra.join(rb).join(rc).join(rd).join(re).join(rf)

egraph = EGraph()
egraph.register(
set_(ra.size).to(50),
set_(rb.size).to(200),
set_(rc.size).to(10),
set_(rd.size).to(123),
set_(re.size).to(10000),
set_(rf.size).to(1),
)


@egraph.register
def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: i64):
# cost of relation is its size minus 1, since the string arg will have a cost of 1 as well
yield rule(JoinTree(s).size == asize).then(set_cost(JoinTree(s), asize - 1))
# cost/size of join is product of sizes
yield rule(a.join(b), a.size == asize, b.size == bsize).then(
set_(a.join(b).size).to(asize * bsize), set_cost(a.join(b), asize * bsize)
)
# associativity
yield rewrite(a.join(b)).to(b.join(a))
# commutativity
yield rewrite(a.join(b).join(c)).to(a.join(b.join(c)))


egraph.register(query)
egraph.run(1000)
print(egraph.extract(query))
print(egraph.extract(query.size))


egraph
13 changes: 13 additions & 0 deletions python/egglog/pretty.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def pretty_callable_ref(
ref: CallableRef,
first_arg: ExprDecl | None = None,
bound_tp_params: tuple[JustTypeRef, ...] | None = None,
include_all_args: bool = False,
) -> str:
"""
Pretty print a callable reference, using a dummy value for
Expand All @@ -115,6 +116,13 @@ def pretty_callable_ref(
# Either returns a function or a function with args. If args are provided, they would just be called,
# on the function, so return them, because they are dummies
if isinstance(res, tuple):
# If we want to include all args as ARG_STR, then we need to figure out how many to use
# used for set_cost so that `cost(E(...))` will show up as a call
if include_all_args:
signature = decls.get_callable_decl(ref).signature
assert isinstance(signature, FunctionSignature)
correct_args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * len(signature.arg_types)
return f"{res[0]}({', '.join(context(a, parens=False, unwrap_lit=True) for a in correct_args)})"
return res[0]
return res

Expand Down Expand Up @@ -190,6 +198,9 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90
pass
case DefaultRewriteDecl():
pass
case SetCostDecl(_, e, c):
self(e)
self(c)
case _:
assert_never(decl)

Expand Down Expand Up @@ -285,6 +296,8 @@ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na
return f"{change}({self(expr)})", "action"
case PanicDecl(s):
return f"panic({s!r})", "action"
case SetCostDecl(_, expr, cost):
return f"set_cost({self(expr)}, {self(cost, unwrap_lit=True)})", "action"
case EqDecl(_, left, right):
return f"eq({self(left)}).to({self(right)})", "fact"
case RulesetDecl(rules):
Expand Down
30 changes: 30 additions & 0 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,3 +1063,33 @@ def f(x: i64Like) -> i64: ...
egraph.register(set_(f(i64(1))).to(i64(2)))
values = egraph.function_values(f)
assert values == {f(i64(1)): i64(2)}


def test_dynamic_cost():
"""
https://github.com/egraphs-good/egglog-experimental/blob/6d07a34ac76deec751f86f70d9b9358cd3e236ca/tests/integration_test.rs#L5-L35
"""

class E(Expr):
def __init__(self, x: i64Like) -> None: ...
def __add__(self, other: E) -> E: ...
@method(cost=200)
def __sub__(self, other: E) -> E: ...

egraph = EGraph()
egraph.register(
union(E(2)).with_(E(1) + E(1)),
set_cost(E(2), 1000),
set_cost(E(1), 100),
)
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203)
with egraph:
egraph.register(set_cost(E(1) + E(1), 800))
assert egraph.extract(E(2), include_cost=True) == (E(2), 1001)
with egraph:
egraph.register(set_cost(E(1) + E(1), 798))
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 1000)
egraph.register(union(E(2)).with_(E(5) - E(3)))
assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203)
egraph.register(set_cost(E(5) - E(3), 198))
assert egraph.extract(E(2), include_cost=True) == (E(5) - E(3), 202)
Loading