diff --git a/docs/changelog.md b/docs/changelog.md index d1de268b..ce492a8c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Add egglog tutorials, change display to not inline by default, and fix bug looking up binary methods [#352](https://github.com/egraphs-good/egglog-python/pull/352) - Add `back_off` scheduler [#350](https://github.com/egraphs-good/egglog-python/pull/350) ## 11.2.0 (2025-09-03) diff --git a/docs/conf.py b/docs/conf.py index 028b7878..e0c8aed2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,15 +13,15 @@ post_auto_image = 1 -html_sidebars = { - "**": [ - "sidebar-nav-bs", - "ablog/postcard.html", - "ablog/recentposts.html", - "ablog/tagcloud.html", - "ablog/categories.html", - ] -} +# html_sidebars = { +# "**": [ +# "sidebar-nav-bs", +# "ablog/postcard.html", +# "ablog/recentposts.html", +# "ablog/tagcloud.html", +# "ablog/categories.html", +# ] +# } ## # Myst ## @@ -146,8 +146,8 @@ # myst_nb default settings # Custom formats for reading notebook; suffix -> reader -# nb_custom_formats = {} - +# https://github.com/mwouts/jupytext/blob/main/docs/formats-scripts.md#the-light-format +nb_custom_formats = {".py": ["jupytext.reads", {"fmt": "py:light"}]} # Notebook level metadata key for config overrides # nb_metadata_key = 'mystnb' diff --git a/docs/tutorials.md b/docs/tutorials.md index 71add8df..66feef42 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -4,4 +4,9 @@ auto_examples/index tutorials/getting-started tutorials/sklearn +tutorials/tut_1_basics +tutorials/tut_2_datalog +tutorials/tut_3_analysis +tutorials/tut_4_scheduling +tutorials/tut_5_extraction ``` diff --git a/docs/tutorials/__init__.py b/docs/tutorials/__init__.py new file mode 100644 index 00000000..6c102dfa --- /dev/null +++ b/docs/tutorials/__init__.py @@ -0,0 +1,3 @@ +""" +Tutorials +""" diff --git a/docs/tutorials/tut_1_basics.py b/docs/tutorials/tut_1_basics.py new file mode 100644 index 
00000000..33731594 --- /dev/null +++ b/docs/tutorials/tut_1_basics.py @@ -0,0 +1,183 @@ +# # 01 - Basics of Equality Saturation +# +# _[This tutorial is translated from egglog.](https://egraphs-good.github.io/egglog-tutorial/01-basics.html)_ +# +# In this tutorial, we will build an optimizer for a subset of linear algebra using egglog. +# We will start by optimizing simple integer arithmetic expressions. +# Our initial DSL supports constants, variables, addition, and multiplication. + +# + +# mypy: disable-error-code="empty-body" +from __future__ import annotations +from typing import TypeAlias +from collections.abc import Iterable +from egglog import * + + +class Num(Expr): + def __init__(self, value: i64Like) -> None: ... + + @classmethod + def var(cls, name: StringLike) -> Num: ... + + def __add__(self, other: NumLike) -> Num: ... + def __mul__(self, other: NumLike) -> Num: ... + + # Support inverse operations for convenience + # they will be translated to non-reversed ones + def __radd__(self, other: NumLike) -> Num: ... + def __rmul__(self, other: NumLike) -> Num: ... + + +NumLike: TypeAlias = Num | StringLike | i64Like +# - + + +# The signature here takes `NumLike` not `Num` so that you can write `Num(1) + 2` instead of +# `Num(1) + Num(2)`. This is helpful for ease of use and also for compatibility when you are trying to +# create expressions that act like Python objects which perform upcasting. +# +# To support this, you must define conversions between primitive types and your expression types. +# When a value is passed into a function, it will find the type it should be converted to and +# transitively apply the conversions you have defined: + +converter(i64, Num, Num) +converter(String, Num, Num.var) + +# Now, let's define some simple expressions. + +egraph = EGraph() +x = Num.var("x") +expr1 = egraph.let("expr1", 2 * (x * 3)) +expr2 = egraph.let("expr2", 6 * x) + +# You should see an e-graph with two expressions. 
+ +egraph + +# We can `.extract` the values of the expressions as well to see their fully expanded forms. + +egraph.extract(String("Hello, world!")) + +egraph.extract(i64(42)) + +egraph.extract(expr1) + +egraph.extract(expr2) + +# We can use the `check` command to check properties of our e-graph. + +x, y = vars_("x y", Num) +egraph.check(expr1 == x * y) + +# This checks if `expr1` is equivalent to some expression `x * y`, where `x` and `y` are +# variables that can be mapped to any `Num` expression in the e-graph. +# +# Checks can fail. For example the following check fails because `expr1` is not equivalent to +# `x + y` for any `x` and `y` in the e-graph. + +egraph.check_fail(expr1 == x + y) + +# Let us define some rewrite rules over our small DSL. + + +@egraph.register +def _add_comm(x: Num, y: Num): + yield rewrite(x + y).to(y + x) + + +# This could also have been written like: +# +# ```python +# x, y = vars_("x y", Num) +# egraph.register(rewrite(x + y).to(y + x)) +# ``` +# +# In this tutorial we will use the function form to define rewrites and rules, because then we only +# have to write the variable names once as arguments and they are not leaked to the outer scope. + + +# This rule asserts that addition is commutative. More concretely, this rule says, if the e-graph +# contains expressions of the form `x + y`, then the e-graph should also contain the +# expression `y + x`, and they should be equivalent. +# +# Similarly, we can define the associativity rule for addition. + + +@egraph.register +def _add_assoc(x: Num, y: Num, z: Num) -> Iterable[RewriteOrRule]: + yield rewrite(x + (y + z)).to((x + y) + z) + + +# This rule says, if the e-graph contains expressions of the form `x + (y + z)`, then the e-graph should also contain +# the expression `(x + y) + z`, and they should be equivalent. + +# There are two subtleties to rules: +# +# 1. Defining a rule is different from running it.
The following check would fail at this point +# because the commutativity rule has not been run (we've inserted `x + 3` but not yet derived `3 + x`). + +egraph.check_fail((x + 3) == (3 + x)) + +# 2. Rules are not instantiated for every possible term; they are only instantiated for terms that are +# in the e-graph. For instance, even if we ran the commutativity rule above, the following check would +# still fail because the e-graph does not contain either of the terms `Num(-2) + Num(2)` or `Num(2) + Num(-2)`. + +egraph.check_fail(Num(-2) + 2 == Num(2) + -2) + +# Let's also define commutativity and associativity for multiplication. + + +@egraph.register +def _mul(x: Num, y: Num, z: Num) -> Iterable[RewriteOrRule]: + yield rewrite(x * y).to(y * x) + yield rewrite(x * (y * z)).to((x * y) * z) + + +# `egglog` also defines a set of built-in functions over primitive types, such as `+` and `*`, +# and supports operator overloading, so the same operator can be used with different types. + +egraph.extract(i64(1) + 2) + +egraph.extract(String("1") + "2") + +egraph.extract(f64(1.0) + 2.0) + +# With primitives, we can define rewrite rules that talk about the semantics of operators. +# The following rules show constant folding over addition and multiplication. + + +@egraph.register +def _const_fold(a: i64, b: i64) -> Iterable[RewriteOrRule]: + yield rewrite(Num(a) + Num(b)).to(Num(a + b)) + yield rewrite(Num(a) * Num(b)).to(Num(a * b)) + + +# While we have defined several rules, the e-graph has not changed since we inserted the two +# expressions. To run rules we have defined so far, we can use `run`. + +egraph.run(10) + +# This tells `egglog` to run our rules for 10 iterations. 
More precisely, egglog runs the +# following pseudo code: +# +# ``` +# for i in range(10): +# for r in rules: +# ms = r.find_matches(egraph) +# for m in ms: +# egraph = egraph.apply_rule(r, m) +# egraph = rebuild(egraph) +# ``` +# +# In other words, `egglog` computes all the matches for one iteration before making any +# updates to the e-graph. This is in contrast to an evaluation model where rules are immediately +# applied and the matches are obtained on demand over a changing e-graph. +# +# We can now look at the e-graph and see that `2 * (x * 3)` and `6 * x` are now in the same E-class. + +egraph + +# We can also check this fact explicitly + +egraph.check(expr1 == expr2) diff --git a/docs/tutorials/tut_2_datalog.py b/docs/tutorials/tut_2_datalog.py new file mode 100644 index 00000000..f1c479dd --- /dev/null +++ b/docs/tutorials/tut_2_datalog.py @@ -0,0 +1,253 @@ +# # 02 - Datalog +# +# _[This tutorial is translated from egglog.](https://egraphs-good.github.io/egglog-tutorial/02-datalog.html)_ +# +# Datalog is a relational language for deductive reasoning. In the last lesson, we wrote our first +# equality saturation program in egglog, but you can also write rules for deductive reasoning a la Datalog. +# In this lesson, we will write several classic Datalog programs in egglog. One of the benefits +# of egglog being a language for program optimization is that it can talk about terms natively, +# so in egglog we get Datalog with terms for free. + + +# mypy: disable-error-code="empty-body" +from __future__ import annotations +from collections.abc import Iterable +from typing import TypeAlias +from egglog import * + +# Let's first define relations edge and path. +# We use `edge(a, b)` to mean the tuple (a, b) is in the `edge` relation. +# `edge(a, b)` means there is a directed edge from `a` to `b`, +# and we will use it to compute the `path` relation, +# where `path(a, b)` means there is a (directed) path from `a` to `b`.
+ +edge = relation("edge", i64, i64) +path = relation("path", i64, i64) + +# We can insert edges into our relation by asserting facts: + +egraph = EGraph() +egraph.register( + edge(i64(1), i64(2)), + edge(i64(2), i64(3)), + edge(i64(3), i64(4)), +) + +# Fact definitions are similar to the `egraph.let` definitions in the last lesson, +# in that facts are immediately added to relations. +# +# Now let's tell egglog how to derive the `path` relation. +# +# First, if an edge from a to b exists, then it is already a proof +# that there exists a path from a to b. + + +@egraph.register +def _(a: i64, b: i64) -> Iterable[RewriteOrRule]: + yield rule(edge(a, b)).then(path(a, b)) + + +# A rule has the form `rule(atom1, atom2, ...).then(action1, action2, ...)`. +# +# For the rule we have just defined, the only atom is `edge(a, b)`, which asks egglog to search +# for possible `a` and `b`s such that `edge(a, b)` is a fact in the database. +# +# We call the first part the "query" of a rule, and the second part the "body" of a rule. +# In Datalog terminology, confusingly, the first part is called the "body" of the rule +# while the second part is called the "head" of the rule. This is because Datalog rules +# are usually written as `head :- body`. To avoid confusion, we will refrain from using +# Datalog terminology. +# +# The rule above defines the base case of the path relation. The inductive case reads as follows: +# if we know there is a path from `a` to `b`, and there is an edge from `b` to `c`, then +# there is also a path from `a` to `c`. +# This can be expressed as the egglog rule below: + + +@egraph.register +def _(a: i64, b: i64, c: i64) -> Iterable[RewriteOrRule]: + yield rule(path(a, b), edge(b, c)).then(path(a, c)) + + +# Again, defining a rule does not mean running it in egglog, which may be a surprise to those familiar with Datalog. +# The user still needs to run the program. +# For instance, the following check would fail at this point.
+ +egraph.check_fail(path(i64(1), i64(4))) + +# But it passes after we run our rules for 10 iterations. + +egraph.run(10) +egraph.check(path(i64(1), i64(4))) + +# For many deductive rules, we do not know the number of iterations +# needed to reach a fixpoint. The egglog language provides the `saturate` scheduling primitive to run the rules until fixpoint. + +egraph.run(run().saturate()) +egraph + +# We will cover more details about schedules later in the tutorial. + + +# Our last example determines whether there is a path from one node to another, +# but we don't know the details about the path. +# Let's slightly extend our program to obtain the length of the shortest path between any two nodes. + + +# + +@function +def edge_len(from_: i64Like, to: i64Like) -> i64: ... + + +@function(merge=lambda old, new: old.min(new)) +def path_len(from_: i64Like, to: i64Like) -> i64: ... + + +# - + +# Here, we use a new decorator called `function` to define a table that respects the functional dependency. +# A relation is just a function with output domain `Unit`. +# By defining `edge_len` and `path_len` with `function`, we can associate a length to each path. +# +# What happens if the same tuple of a function is mapped to two values? +# In the case of relation, this is easy: `Unit` only has one value, so the two values must be identical. +# But in general, that would be a violation of functional dependency, the property that `a = b` implies `f(a) = f(b)`. +# Egglog allows us to specify how to reconcile two values that are mapped from the same tuple using _merge expressions_. +# +# For instance, for `path_len`, the merge expression is `old.min(new)`. The merge function is passed two special values +# `old` and `new` that denote the current output of the tuple and the output of the new, to-be-inserted value. +# The merge expression for `path_len` says that, when there are two paths from `a` to `b` with lengths `old` and `new`, +# we keep the shorter one, i.e., `old.min(new)`.
+# +# For `edge_len`, we can define the merge expression the same as `path_len`, which means that we only keep the shortest edge +# if there are multiple edges. But we can also assert that `edge_len` does not have a merge expression, +# which is the default if none is provided. +# This means we don't expect there will be multiple edges between two nodes. More generally, it is the user's +# responsibility to ensure no tuples with conflicting output values exist. If a conflict happens, egglog will +# raise an error. + +# Now let's insert the same edges as before, but we will assign a length to each edge. This is done using the `set` action, +# which takes a tuple and an output value: + +egraph = EGraph() +egraph.register( + set_(edge_len(1, 2)).to(i64(10)), + set_(edge_len(2, 3)).to(i64(10)), + set_(edge_len(1, 3)).to(i64(30)), +) + +# Let us define the reflexive rule and transitive rule for the `path` function. +# In this rule, we use the `set` action to set the output value of the `path` function. +# On the query side, we use `a == b` (or `eq(a).to(b)` if equality is overloaded) to bind the output value of a function. + + +@egraph.register +def _(a: i64, b: i64, c: i64, ab: i64, bc: i64) -> Iterable[RewriteOrRule]: + yield rule(edge_len(a, b) == ab).then(set_(path_len(a, b)).to(ab)) + yield rule(path_len(a, b) == ab, edge_len(b, c) == bc).then(set_(path_len(a, c)).to(ab + bc)) + + +# Let's run our rules and check we get the desired shortest path + +egraph.run(run().saturate()) +egraph.check(path_len(1, 3) == 20) +egraph + + +# Now let us combine the knowledge we have learned in lessons 1 and 2 to write a program that combines +# both equality saturation and Datalog. +# +# We reuse our path example, but this time the nodes are terms constructed using the `Node` constructor, +# +# We start by defining a new, union-able sort (created by sub-classing expression) with a new constructor + + +# + +class Node(Expr): + def __init__(self, x: i64Like) -> None: ... 
+ def edge(self, other: NodeLike) -> Unit: ... + def path(self, other: NodeLike) -> Unit: ... + + +NodeLike: TypeAlias = Node | i64Like + +converter(i64, Node, Node) +# - + +# Note: We could have equivalently written +# +# ```python +# class Node(Sort): ... +# +# @function +# def mk(x: i64Like) -> Node: ... +# edge_node = relation("edge_node", Node, Node) +# path_node = relation("path_node", Node, Node) +# ``` +# +# All methods of classes are syntactic sugar for creating functions. Note that properties, classmethods, and classvars +# are also all supported ways of defining functions. + + +# + +egraph = EGraph() + + +@egraph.register +def _(x: Node, y: Node, z: Node) -> Iterable[RewriteOrRule]: + yield rule(x.edge(y)).then(x.path(y)) + yield rule(x.path(y), y.edge(z)).then(x.path(z)) + + +egraph.register( + Node(1).edge(2), + Node(2).edge(3), + Node(3).edge(1), + Node(5).edge(6), +) +egraph +# - + +# Because we defined our nodes using custom expression `sort`, we can "union" two nodes. +# This makes them indistinguishable to rules in `egglog`. + +egraph.register(union(Node(3)).with_(Node(5))) +egraph + +# `union` is a new function here, but it is our old friend: `rewrite`s are implemented as rules whose +# actions are `union`s. For instance, `rewrite(a + b).to(b + a)` is lowered to the following +# rule: +# +# ```python +# rule(a + b == e).then(union(e).with_(b + a))) +# ``` + +# + +egraph.run(run().saturate()) +egraph.check( + Node(3).edge(6), + Node(1).path(6), +) +egraph +# - + +# We can also give a new meaning to equivalence by adding the following rule. + + +@egraph.register +def _(x: Node, y: Node) -> Iterable[RewriteOrRule]: + yield rule(x.path(y), y.path(x)).then(union(x).with_(y)) + + +# This rule says that if there is a path from `x` to `y` and from `y` to `x`, then +# `x` and `y` are equivalent. +# This rule allows us to tell if `a` and `b` are in the same connected component by checking +# `egraph.check(Node(a) == Node(b))`. 
+ +egraph.run(run().saturate()) +egraph.check( + Node(1) == Node(2), + Node(1) == Node(3), + Node(2) == Node(3), +) +egraph diff --git a/docs/tutorials/tut_3_analysis.py b/docs/tutorials/tut_3_analysis.py new file mode 100644 index 00000000..82f8bd1a --- /dev/null +++ b/docs/tutorials/tut_3_analysis.py @@ -0,0 +1,219 @@ +# # 03 - E-class Analysis +# +# _[This tutorial is translated from egglog.](https://egraphs-good.github.io/egglog-tutorial/03-analysis.html)_ +# +# Datalog is a relational language for deductive reasoning. In the last lesson, we write our first +# equality saturation program in egglog, but you can also write rules for deductive reasoning a la Datalog. +# In this lesson, we will write several classic Datalog programs in egglog. One of the benifits +# of egglog being a language for program optimization is that it can talk about terms natively, +# so in egglog we get Datalog with terms for free. +# +# In this lesson, we learn how to combine the power of equality saturation and Datalog. +# We will show how we can define program analyses using Datalog-style deductive reasoning, +# how EqSat-style rewrite rules can make the program analyses more accurate, and how +# accurate program analyses can enable more powerful rewrites. +# +# Our first example will continue with the `path` example in [lesson 2](./tut_2_datalog). +# In this case, there is a path from `e1` to `e2` if `e1` is less than or equal to `e2`. + +# + +# mypy: disable-error-code="empty-body" +from __future__ import annotations +from collections.abc import Iterable +from typing import TypeAlias +from egglog import * + + +class Num(Expr): + # in this example we use big 🐀 to represent numbers + # you can find a list of primitive types in the standard library in [`builtins.py`](https://github.com/egraphs-good/egglog-python/blob/main/python/egglog/builtins.py) + def __init__(self, value: BigRatLike) -> None: ... + + @classmethod + def var(cls, name: StringLike) -> Num: ... 
+ def __add__(self, other: NumLike) -> Num: ... + def __radd__(self, other: NumLike) -> Num: ... + def __mul__(self, other: NumLike) -> Num: ... + def __rmul__(self, other: NumLike) -> Num: ... + def __truediv__(self, other: NumLike) -> Num: ... + def __le__(self, other: NumLike) -> Unit: ... + + @property + def non_zero(self) -> Unit: ... + + +NumLike: TypeAlias = Num | StringLike | BigRatLike +converter(BigRat, Num, Num) +converter(String, Num, Num.var) +# - + +# Let's define some BigRat constants that will be useful later. + +zero = BigRat(0, 1) +one = BigRat(1, 1) +two = BigRat(2, 1) + + +# We define a less-than-or-equal-to relation between two expressions. +# `a.__le__(b)` means that `a <= b` for all possible values of variables. + +# We define rules to deduce the `le` relation. + +egraph = EGraph() + + +@egraph.register +def _( + e1: Num, e2: Num, e3: Num, n1: BigRat, n2: BigRat, x: String, e1a: Num, e1b: Num, e2a: Num, e2b: Num +) -> Iterable[RewriteOrRule]: + # We start with transitivity of `<=`: + yield rule(e1 <= e2, e2 <= e3).then(e1 <= e3) + # Base case for for `Num`: + yield rule(e1 == Num(n1), e2 == Num(n2), n1 <= n2).then(e1 <= e2) + # Base case for `Var`:` + yield rule(e1 == Num.var(x)).then(e1 <= e1) # noqa: PLR0124 + # Recursive case for `Add`: + yield rule( + e1 == (e1a + e1b), + e2 == (e2a + e2b), + e1a <= e2a, + e1b <= e2b, + ).then(e1 <= e2) + + +# Note that we have not defined any rules for multiplication. This would require a more complex +# analysis on the positivity of the expressions. +# +# On the other hand, these rules by themselves are pretty weak. For example, they cannot deduce `x + 1 <= 2 + x`. 
+# But EqSat-style axiomatic rules make these rules more powerful: + + +@egraph.register +def _(x: Num, y: Num, z: Num, a: BigRat, b: BigRat) -> Iterable[RewriteOrRule]: + yield birewrite(x + (y + z)).to((x + y) + z) + yield birewrite(x * (y * z)).to((x * y) * z) + yield rewrite(x + y).to(y + x) + yield rewrite(x * y).to(y * x) + yield rewrite(x * (y + z)).to((x * y) + (x * z)) + yield rewrite(x + zero).to(x) + yield rewrite(x * one).to(x) + yield rewrite(Num(a) + Num(b)).to(Num(a + b)) + yield rewrite(Num(a) * Num(b)).to(Num(a * b)) + + +# To check our rules + +expr1 = egraph.let("expr1", Num.var("y") + (Num(two) + "x")) +expr2 = egraph.let("expr2", Num.var("x") + Num.var("y") + Num(one) + Num(two)) +egraph.check_fail(expr1 <= expr2) +egraph.run(run().saturate()) +egraph.check(expr1 <= expr2) +egraph + +# A useful special case of the <= analysis is if an expression is upper bounded +# or lower bounded by certain numbers, i.e., interval analysis: + + +# + +@function(merge=lambda old, new: old.min(new)) +def upper_bound(e: Num) -> BigRat: ... + + +@function(merge=lambda old, new: old.max(new)) +def lower_bound(e: Num) -> BigRat: ... + + +# - + +# In the above functions, unlike `<=`, we define upper bound and lower bound as functions from +# expressions to a unique number. +# This is because we are always interested in the tightest upper bound +# and lower bounds, so + + +@egraph.register +def _(e: Num, n: BigRat) -> Iterable[RewriteOrRule]: + yield rule(e <= Num(n)).then(set_(upper_bound(e)).to(n)) + yield rule(Num(n) <= e).then(set_(lower_bound(e)).to(n)) + + +# We can define more specific rules for obtaining the upper and lower bounds of an expression +# based on the upper and lower bounds of its children. 
+ + +@egraph.register +def _(e: Num, e1: Num, e2: Num, u1: BigRat, u2: BigRat, l1: BigRat, l2: BigRat) -> Iterable[RewriteOrRule]: + yield rule( + e == (e1 + e2), + upper_bound(e1) == u1, + upper_bound(e2) == u2, + ).then(set_(upper_bound(e)).to(u1 + u2)) + yield rule( + e == (e1 + e2), + lower_bound(e1) == l1, + lower_bound(e2) == l2, + ).then(set_(lower_bound(e)).to(l1 + l2)) + # ... and the giant rule for multiplication: + yield rule( + e == (e1 * e2), + l1 == lower_bound(e1), + l2 == lower_bound(e2), + u1 == upper_bound(e1), + u2 == upper_bound(e2), + ).then( + set_(lower_bound(e)).to((l1 * l2).min((l1 * u2).min((u1 * l2).min(u1 * u2)))), + set_(upper_bound(e)).to((l1 * l2).max((l1 * u2).max((u1 * l2).max(u1 * u2)))), + ) + # Similarly, + yield rule(e == e1 * e1).then(set_(lower_bound(e)).to(zero)) + + +# The interval analysis is not only useful for numerical tools like [Herbie](https://herbie.uwplse.org/), +# but it can also guard certain optimization rules, making EqSat-based rewriting more powerful! +# +# For example, we are interested in non-zero expressions + + +@egraph.register +def _(e: Num, e2: Num) -> Iterable[RewriteOrRule]: + yield rule(lower_bound(e) > zero).then(e.non_zero) + yield rule(upper_bound(e) < zero).then(e.non_zero) + yield rewrite(e / e).to(Num(one), e.non_zero) + yield rewrite(e * (e2 / e)).to(e2, e.non_zero) + + +# This non-zero analysis lets us optimize expressions that contain division safely. +# 2 * (x / (1 + 2 / 2)) is equivalent to x + +expr3 = egraph.let("expr3", Num(two) * (Num.var("x") / (Num(one) + (Num(two) / Num(two))))) +expr4 = egraph.let("expr4", Num.var("x")) +egraph.check_fail(expr3 == expr4) +egraph.run(run().saturate()) +egraph.check(expr3 == expr4) + +# (x + 1)^2 + 2 + +expr5 = egraph.let("expr5", (Num.var("x") + Num(one)) * (Num.var("x") + Num(one)) + Num(two)) +expr6 = egraph.let("expr6", expr5 / expr5) +egraph.run(run().saturate()) +egraph.check(expr6 == Num(one)) + +# ## Debugging tips! 
+ +# `function_size` is used to return the size of a table and `all_function_sizes` to return the size of every table. +# This is useful for debugging performance, by seeing how the table sizes evolve as the iteration count increases. + +egraph.function_size(Num.__le__) + +egraph.all_function_sizes() + +# `function_values` extracts every instance of a constructor, function, or relation in the e-graph. +# It takes the maximum number of instances to extract as a second argument, so as not to spend time +# printing millions of rows. `function_values` is particularly useful when debugging small e-graphs. + +list(egraph.function_values(Num.__le__, 15)) + +# `extract_multiple` can be used to extract several different "variants" of its +# first argument; the second argument is the number of variants to extract. This is useful when trying to figure out why one e-class is failing to be unioned with another. + +egraph.extract_multiple(expr3, 3) diff --git a/docs/tutorials/tut_4_scheduling.py b/docs/tutorials/tut_4_scheduling.py new file mode 100644 index 00000000..0e74683c --- /dev/null +++ b/docs/tutorials/tut_4_scheduling.py @@ -0,0 +1,248 @@ +# # 04 - Scheduling +# +# _[This tutorial is translated from egglog.](https://egraphs-good.github.io/egglog-tutorial/04-scheduling.html)_ +# +# In this lesson, we will learn how to use `run-schedule` to improve the performance of egglog. +# We start by using the same language as the previous lesson. + +# mypy: disable-error-code="empty-body" +from __future__ import annotations +from collections.abc import Iterable +from egglog import * +from tut_3_analysis import Num, zero, one, upper_bound, lower_bound, two + + +# ## Rulesets +# +# Different from lesson 3, we organize our rules into "rulesets". +# A ruleset is exactly what it sounds like; a set of rules. +# We can declare rulesets using the `ruleset` function. + +optimizations = ruleset() +analysis = ruleset() + +# We can add rules to rulesets by calling the `register` method on the ruleset instead of the egraph.
+# +# We can run rulesets using `run(ruleset)`, or `run()` for running the default ruleset. +# +# Here, we add `<=` rules to the `analysis` ruleset, because they don't add new `Num` nodes to the e-graph. + + +@analysis.register +def _( + e1: Num, e2: Num, e3: Num, n1: BigRat, n2: BigRat, x: String, e1a: Num, e1b: Num, e2a: Num, e2b: Num +) -> Iterable[RewriteOrRule]: + yield rule(e1 <= e2, e2 <= e3).then(e1 <= e3) + yield rule(e1 == Num(n1), e2 == Num(n2), n1 <= n2).then(e1 <= e2) + yield rule(e1 == Num.var(x)).then(e1 <= e1) # noqa: PLR0124 + yield rule( + e1 == (e1a + e1b), + e2 == (e2a + e2b), + e1a <= e2a, + e1b <= e2b, + ).then(e1 <= e2) + + +# In contrast, the following axiomatic rules are doing optimizations, so we add them to the `optimizations` ruleset. + + +@optimizations.register +def _(x: Num, y: Num, z: Num, a: BigRat, b: BigRat) -> Iterable[RewriteOrRule]: + yield birewrite(x + (y + z)).to((x + y) + z) + yield birewrite(x * (y * z)).to((x * y) * z) + yield rewrite(x + y).to(y + x) + yield rewrite(x * y).to(y * x) + yield rewrite(x * (y + z)).to((x * y) + (x * z)) + yield rewrite(x + zero).to(x) + yield rewrite(x * one).to(x) + yield rewrite(Num(a) + Num(b)).to(Num(a + b)) + yield rewrite(Num(a) * Num(b)).to(Num(a * b)) + + +# Here we add the rest of the rules from the last section, but tagged with the appropriate rulesets. 
+ + +@analysis.register +def _(e: Num, n: BigRat, e1: Num, e2: Num, u1: BigRat, u2: BigRat, l1: BigRat, l2: BigRat) -> Iterable[RewriteOrRule]: + yield rule(e <= Num(n)).then(set_(upper_bound(e)).to(n)) + yield rule(Num(n) <= e).then(set_(lower_bound(e)).to(n)) + yield rule( + e == (e1 + e2), + upper_bound(e1) == u1, + upper_bound(e2) == u2, + ).then(set_(upper_bound(e)).to(u1 + u2)) + yield rule( + e == (e1 + e2), + lower_bound(e1) == l1, + lower_bound(e2) == l2, + ).then(set_(lower_bound(e)).to(l1 + l2)) + yield rule( + e == (e1 * e2), + l1 == lower_bound(e1), + l2 == lower_bound(e2), + u1 == upper_bound(e1), + u2 == upper_bound(e2), + ).then( + set_(lower_bound(e)).to((l1 * l2).min((l1 * u2).min((u1 * l2).min(u1 * u2)))), + set_(upper_bound(e)).to((l1 * l2).max((l1 * u2).max((u1 * l2).max(u1 * u2)))), + ) + yield rule(e == e1 * e1).then(set_(lower_bound(e)).to(zero)) + yield rule(lower_bound(e) > zero).then(e.non_zero) + yield rule(upper_bound(e) < zero).then(e.non_zero) + + +# Finally, we have optimization rules that depend on the analysis rules we defined above. + + +@optimizations.register +def _(e: Num, e2: Num) -> Iterable[RewriteOrRule]: + yield rewrite(e / e).to(Num(one), e.non_zero) + yield rewrite(e * (e2 / e)).to(e2, e.non_zero) + + +# Now consider the following program, which consists of a long sequence of additions _inside_ +# a cancelling division. +egraph = EGraph() +addition_chain = egraph.let("addition_chain", "a" + ("b" + ("c" + ("d" + ("e" + Num.var("f")))))) +nonzero_expr = egraph.let("nonzero_expr", Num(one) + (Num(one) + (Num(one) + (Num(one) + Num(two))))) +expr = egraph.let("expr", nonzero_expr * (addition_chain / nonzero_expr)) + +# We want the following check to pass after running the rules. + +egraph.check_fail(expr == addition_chain) + +# To make this check pass, we have to first discover that `nonzero_expr` is indeed non-zero, +# which allows the rule from `x * (y / x)` to `y` to fire. 
+# On the other hand, if we apply the optimization rules, we risk the exponential blowup from +# the associative and commutative permutations of the `addition_chain`. +# +# Therefore, if we try to run both rulesets directly, egglog will spend lots of effort reassociating and +# commuting the terms in the `addition_chain`, even though the optimization that we actually +# want to run only takes one iteration. However, that optimization requires knowing a fact +# that takes multiple iterations to compute (propagating lower- and upper-bounds +# through `nonzero_expr`). We can build a more efficient *schedule*. + +# ## Schedules + +# Our schedule starts by saturating the analysis rules, fully propagating the `non_zero` information _without_ +# adding any e-nodes to the e-graph. + +egraph.run(analysis.saturate()) + +# Then, just run one iteration of the `optimizations` ruleset. + +egraph.run(optimizations) + +# Or equivalently, +# +# ```python +# egraph.run(analysis.saturate() + optimizations) +# ``` +# +# This makes our check pass +egraph.check(expr == addition_chain) + +# While the above program is effective at optimizing that specific program, it would fail if +# we had a slightly more complex program where we had to interleave the optimizations and analyses +# to derive the optimal program. +# For expressing more complex schedules like these, `egglog` supports a scheduling sub-language, +# with primitives `repeat`, `seq`, `saturate`, and `run`. + + +# The idea behind the following schedule is to always saturate analyses before running optimizations. +# This combination is wrapped in a `repeat` block to give us control over how long to run egglog. +# With `repeat 1` it is the same schedule as before, but now we can increase the iteration +# count if we want to optimize harder with more time and space budget. 
+ +egraph.run((analysis.saturate() + optimizations) * 2) + + +# Running more iterations does not help our above example per se, +# but if we had started with a slightly more complex program to optimize... + +egraph = EGraph() +addition_chain = egraph.let("addition_chain", "a" + ("b" + ("c" + ("d" + ("e" + Num.var("f")))))) +x_times_zero = egraph.let("x_times_zero", Num.var("x") * zero) +nonzero_expr = egraph.let("nonzero_expr", Num(one) + (Num(one) + (Num(one) + (Num(one) + x_times_zero)))) +expr = egraph.let("expr", nonzero_expr * (addition_chain / nonzero_expr)) + +# For the purpose of this example, we add this rule + + +@optimizations.register +def _(x: Num) -> Iterable[RewriteOrRule]: + yield rewrite(x * zero).to(Num(zero)) + + +# To prove `expr` is equivalent to `addition_chain` by applying the cancellation law, +# we need to prove `nonzero_expr` is nonzero, which requires proving +# `x_times_zero`'s bound. +# To show `x_times_zero`'s bound, we need to apply an optimization rule to rewrite +# it to 0. +# In other words, this requires running analyses in between two runs of optimization rules +# (the cancellation law and `*`'s identity law) + +# Therefore, only running our schedule with one iteration (`repeat 1`) does not give us the optimal program. +# Note that here we used the context manager of e-graph, which calls `egraph.push()` and `egraph.pop()` automatically, +# to create a copy of the e-graph to run our schedule on, which is then reverted at the end. + +with egraph: + egraph.run(analysis.saturate() + optimizations) + extracted = egraph.extract(expr) +extracted + +# Instead, we need to increase the iteration number. +with egraph: + egraph.run((analysis.saturate() + optimizations) * 2) + extracted = egraph.extract(expr) +extracted + +# ## Using custom schedulers + +# However, sometimes just having an iteration number does not give you enough control. 
+# For example, for many rules, such as associativity and commutativity (AC), the size of the e-graph grows hyper-exponentially +# with respect to the number of iterations. + +# Let's go back to this example, and run for five iterations. +# (push) +with egraph: + egraph.run((analysis.saturate() + optimizations) * 5) + assert egraph.function_size(Num.__mul__) == 582 + +# At iteration 5, the `Mul` function has size 582. However, if we bump that to 6, +# the size of the `Mul` function will increase to 13285! Therefore, the iteration number is too coarse +# of a granularity for defining the search space. + +# To this end, egglog provides a scheduler mechanism. A scheduler can decide which matches are important and need to be applied, +# while others can be delayed or skipped. To use a scheduler, pass it in as the `scheduler` argument to `run`. +# +# Currently, `egglog-experimental` implements one scheduler, `back_off`. The idea of `back_off` is that it will ban a rule from applying if that rule grows the +# e-graph too fast. The decision to ban is based on a threshold, which is initially small and increases as rules are banned. +# This scheduler works well when the ruleset contains explosive rules like AC. + +# In this example, the back-off scheduler can prevent the associativity rule +# from dominating the equality saturation: when the associativity rule (or any other rule) is fired too much, +# the scheduler will automatically ban this rule for a few iterations, so that other rules can catch up. + +egraph.run(run(optimizations, scheduler=back_off()) * 10) +egraph.function_size(Num.__mul__) + + +# Note that any scheduler which doesn't have an explicit scope is bound to the outer loop like: +# +# ```python +# bo = back_off() +# egraph.run(bo.scope(run(optimizations, scheduler=bo) * 10)) +# ``` + + +# It is important that the scheduler `bo` is instantiated outside the `repeat` loop, since each scheduler carries some state that is updated +# when run.
For example, the following schedule has very different semantics than the schedule above. +# +# ```python +# bo = back_off() +# egraph.run(bo.scope(run(optimizations, scheduler=bo)) * 10) +# ``` +# +# This schedule instantiates a (fresh) `back-off` scheduler for each `run-with`, so the ten iterations of rulesets are all run +# with the initial configuration of the `back-off` scheduler, which has a very low threshold for banning rules. diff --git a/docs/tutorials/tut_5_extraction.py b/docs/tutorials/tut_5_extraction.py new file mode 100644 index 00000000..b62f4ef5 --- /dev/null +++ b/docs/tutorials/tut_5_extraction.py @@ -0,0 +1,151 @@ +# # 05 - Extraction and Cost +# +# _[This tutorial is translated from egglog.](https://egraphs-good.github.io/egglog-tutorial/05-cost-model-and-extraction.html)_ +# +# In this lesson, we will learn how to extract terms from an e-graph and how to customize the cost model used for extraction. +# We start by using the same language as the previous lesson. + + +# In the previous sessions, we have seen examples of defining and analyzing syntactic terms in egglog. +# After running the rewrite rules, the e-graph may contain a myriad of terms. +# We often want to pick out one or a handful of terms for further processing. +# Extraction is the process of picking out individual terms out of the many terms represented by an e-graph. +# We have seen the `extract` command in the previous sessions, which allows us to extract the optimal term from the e-graph. +# +# Optimality needs to be defined with regard to some cost model. +# A cost model is a function that assigns a cost to each term in the e-graph. +# By default, `extract` uses AST size as its cost model and picks the term with the smallest cost. +# +# In this session, we will show several ways of customizing the cost model in egglog. +# Let's first see a simple example of setting costs with the `cost` argument. + + +# Here we have the same `Num` language but annotated with `cost` keywords.
+ +# + +# mypy: disable-error-code="empty-body" +from __future__ import annotations +from typing import TypeAlias +from collections.abc import Iterable +from egglog import * + + +class Num(Expr): + def __init__(self, value: i64Like) -> None: ... + + @classmethod + def var(cls, name: StringLike) -> Num: ... + + @method(cost=2) + def __add__(self, other: NumLike) -> Num: ... + @method(cost=10) + def __mul__(self, other: NumLike) -> Num: ... + + # These will be translated to non-reversed ones + def __radd__(self, other: NumLike) -> Num: ... + def __rmul__(self, other: NumLike) -> Num: ... + + +NumLike: TypeAlias = Num | StringLike | i64Like +converter(i64, Num, Num) +converter(String, Num, Num.var) +# - + + +# The default cost of a function is 1. +# Intuitively, the additional `cost` attributes mark the multiplication operation as more expensive than addition. +# +# Let's look at how cost is computed for a concrete term in the default tree cost model. + +egraph = EGraph() +expr = egraph.let("expr", Num.var("x") * 2 + 1) + +# This term has a total cost of 18 because: +# +# ```python +# ( +# ( +# Num.var("x") # cost = 1 (from Num.var) + 1 (from "x") = 2 +# * # cost = 10 (from *) + 2 (from left operand) + 2 (from right operand) = 14 +# Num(2) # cost = 1 (from Num) + 1 (from 2) = 2 +# ) +# + # cost = 2 (from +) + 14 (from left operand) + 2 (from right operand) = 18 +# Num(1) # cost = 1 (from Num) + 1 (from 1) = 2 +# ) +# ``` +# +# +# We can use the `extract` command to extract the lowest cost variant of the term. +# For now it gives the only version that we just defined. We can also pass `include_cost=True` to see the cost of the extracted term. 
+ + +egraph.extract(expr, include_cost=True) + +# Let's introduce more variants with rewrites + + +@egraph.register +def _(x: Num) -> Iterable[RewriteOrRule]: + yield rewrite(x * 2).to(x + x) + + +egraph.run(1) +egraph.extract(expr, include_cost=True) + + +# It now extracts the lower cost variant that corresponds to `x + x + 1`, which is equivalent to the original term. +# If there are multiple variants of the same lowest cost, `extract` breaks ties arbitrarily. + + +# ## Setting custom cost for e-nodes + +# The `cost` keyword sets a uniform additional cost for each appearance of the corresponding constructor. +# However, this is not expressive enough to cover the case where the additional cost of an operation is not a fixed constant. +# We can use the `set_cost` feature provided by `egglog-experimental` to get more fine-grained control of individual e-node's cost. + +# To show how this feature works, we define a toy language of matrices. This feature is automatically enabled for +# constructors it is used on. + + +class Matrix(Expr): + def __init__(self, rows: i64Like, cols: i64Like) -> None: ... + def __matmul__(self, other: Matrix) -> Matrix: ... + + # We also define two analyses for the number of rows and columns + @property + def row(self) -> i64: ... + @property + def col(self) -> i64: ...
+ + +@egraph.register +def _(x: Matrix, y: Matrix, z: Matrix, r: i64, c: i64, m: i64) -> Iterable[RewriteOrRule]: + yield rule(x == Matrix(r, c)).then(set_(x.row).to(r), set_(x.col).to(c)) + yield rule( + x == (y @ z), + r == y.row, + y.col == z.row, + c == z.col, + ).then(set_(x.row).to(r), set_(x.col).to(c)) + + # Now we define the cost of matrix multiplication as a product of the dimensions + yield rule( + y @ z, + r == y.row, + m == y.col, + c == z.col, + ).then(set_cost(y @ z, r * m * c)) + + yield birewrite(x @ (y @ z)).to((x @ y) @ z) + + +# Let's optimize matrix multiplication with this cost model + +Mexpr = egraph.let("Mexpr", (Matrix(64, 8) @ Matrix(8, 256)) @ Matrix(256, 2)) +egraph.run(5) + +# Thanks to our cost model, egglog is able to extract the equivalent program with lowest cost using the dimension information we provided: + +egraph.extract(Mexpr) + +egraph diff --git a/pyproject.toml b/pyproject.toml index 64cb15b1..c6c55d3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ docs = [ "line-profiler", "sphinxcontrib-mermaid", "ablog", + "jupytext", ] @@ -222,6 +223,7 @@ preview = true [tool.ruff.lint.per-file-ignores] # Don't require annotations for tests "python/tests/**" = ["ANN001", "ANN201", "INP001"] +"docs/**" = ["I001", "PLW0131"] # Disable these tests instead for now since ruff doesn't support including all method annotations of decorated class # [tool.ruff.lint.flake8-type-checking] diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 6064941e..9735495c 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -103,6 +103,10 @@ def value(self) -> str: @method(egg_fn="replace") def replace(self, old: StringLike, new: StringLike) -> String: ... 
+ @method(preserve=True) + def __add__(self, other: StringLike) -> String: + return join(self, other) + StringLike: TypeAlias = String | str diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index d0f451cf..f391c642 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -1036,7 +1036,7 @@ def _serialize( split_primitive_outputs = kwargs.pop("split_primitive_outputs", True) split_functions = kwargs.pop("split_functions", []) include_temporary_functions = kwargs.pop("include_temporary_functions", False) - n_inline_leaves = kwargs.pop("n_inline_leaves", 1) + n_inline_leaves = kwargs.pop("n_inline_leaves", 0) serialized = self._egraph.serialize( [], max_functions=max_functions, diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index f2316ad5..2a5a9a76 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -635,14 +635,22 @@ def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs): for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)): + method_name = f"__r{name[2:]}" if r_method else name - def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object: + def _numeric_binary_method( + self: object, other: object, name: str = name, r_method: bool = r_method, method_name: str = method_name + ) -> object: """ Implements numeric binary operations. Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either the LHS or the RHS as exactly the right type and then upcasting the other to that type. """ + # First check if we have a preserved method for this: + if isinstance(self, RuntimeExpr) and ( + (preserved_method := self.__egg_class_decl__.preserved_methods.get(method_name)) is not None + ): + return preserved_method.__get__(self)(other) # 1. 
switch if reversed method if r_method: self, other = other, self @@ -668,7 +676,6 @@ def _numeric_binary_method(self: object, other: object, name: str = name, r_meth fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self) return fn(other) - method_name = f"__r{name[2:]}" if r_method else name setattr(RuntimeExpr, method_name, _numeric_binary_method) @@ -688,6 +695,8 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]: ): raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}") return expr.callable, decl_thunk() + case types.MethodWrapperType() if isinstance((slf := callable.__self__), RuntimeClass): + return MethodRef(slf.__egg_tp__.name, callable.__name__), slf.__egg_decls__ case _: raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref") diff --git a/uv.lock b/uv.lock index 4b17e7c3..4a72cffe 100644 --- a/uv.lock +++ b/uv.lock @@ -712,6 +712,7 @@ dev = [ { name = "anywidget", extra = ["dev"] }, { name = "array-api-compat" }, { name = "jupyterlab" }, + { name = "jupytext" }, { name = "line-profiler" }, { name = "llvmlite" }, { name = "matplotlib" }, @@ -739,6 +740,7 @@ docs = [ { name = "ablog" }, { name = "anywidget" }, { name = "array-api-compat" }, + { name = "jupytext" }, { name = "line-profiler" }, { name = "llvmlite" }, { name = "matplotlib" }, @@ -786,6 +788,7 @@ requires-dist = [ { name = "egglog", extras = ["docs", "test"], marker = "extra == 'dev'" }, { name = "graphviz" }, { name = "jupyterlab", marker = "extra == 'dev'" }, + { name = "jupytext", marker = "extra == 'docs'" }, { name = "line-profiler", marker = "extra == 'docs'" }, { name = "llvmlite", marker = "extra == 'array'", specifier = ">=0.42.0" }, { name = "matplotlib", marker = "extra == 'docs'" }, @@ -819,7 +822,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", 
marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749 } wheels = [ @@ -1485,6 +1488,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571 }, ] +[[package]] +name = "jupytext" +version = "1.17.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "mdit-py-plugins" }, + { name = "nbformat" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/15/14/41faf71e168fcc6c48268f0fc67ba0d6acf6ee4e2c5c785c2bccb967c29d/jupytext-1.17.3.tar.gz", hash = "sha256:8b6dae76d63c95cad47b493c38f0d9c74491fb621dcd0980abfcac4c8f168679", size = 3753151 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/86/751ec86adb66104d15e650b704f89dddd64ba29283178b9651b9bc84b624/jupytext-1.17.3-py3-none-any.whl", hash = "sha256:09b0a94cd904416e823a5ba9f41bd181031215b6fc682d2b5c18e68354feb17c", size = 166548 }, +] + [[package]] name = "kiwisolver" version = "1.4.9"