Skip to content

Commit b4ac3de

Browse files
Add first tutorial from egglog with jupytext
1 parent aa5f5f2 commit b4ac3de

File tree

5 files changed

+178
-3
lines changed

5 files changed

+178
-3
lines changed

docs/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@
146146
# myst_nb default settings
147147

148148
# Custom formats for reading notebook; suffix -> reader
149-
# nb_custom_formats = {}
150-
149+
# https://github.com/mwouts/jupytext/blob/main/docs/formats-scripts.md#the-light-format
150+
nb_custom_formats = {".py": ["jupytext.reads", {"fmt": "py:light"}]}
151151
# Notebook level metadata key for config overrides
152152
# nb_metadata_key = 'mystnb'
153153

docs/tutorials.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
auto_examples/index
55
tutorials/getting-started
66
tutorials/sklearn
7+
tutorials/tut_1_basics
78
```

docs/tutorials/tut_1_basics.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# # 01 - Basics of Equality Saturation
2+
#
3+
# _[This tutorial is translated from egglog.](https://egraphs-good.github.io/egglog-tutorial/01-basics.html)_
4+
#
5+
# In this tutorial, we will build an optimizer for a subset of linear algebra using egglog.
6+
# We will start by optimizing simple integer arithmetic expressions.
7+
# Our initial DSL supports constants, variables, addition, and multiplication.
8+
9+
# +
10+
# mypy: disable-error-code="empty-body"
11+
from __future__ import annotations
12+
from typing import TypeAlias
13+
from egglog import *
14+
15+
16+
class Num(Expr):
17+
def __init__(self, value: i64Like) -> None: ...
18+
19+
@classmethod
20+
def var(cls, name: StringLike) -> Num: ...
21+
22+
def __add__(self, other: ExprLike) -> Num: ...
23+
def __mul__(self, other: ExprLike) -> Num: ...
24+
25+
# Support reflected (swapped-operand) operations for convenience;
26+
# they will be translated to non-reversed ones
27+
def __radd__(self, other: ExprLike) -> Num: ...
28+
def __rmul__(self, other: ExprLike) -> Num: ...
29+
30+
31+
ExprLike: TypeAlias = Num | StringLike | i64Like
32+
converter(i64, Num, Num)
33+
converter(str, Num, Num.var)
34+
# -
35+
36+
# Now, let's define some simple expressions.
37+
38+
x = Num.var("x")
39+
expr1 = 2 * (x * 3)
40+
expr2 = 6 * x
41+
42+
# You should see an e-graph with two expressions.
43+
44+
egraph = EGraph()
45+
egraph.register(expr1, expr2)
46+
egraph
47+
48+
# We can print the values of the expressions as well to see their fully expanded forms.
49+
50+
String("Hello, world!")
51+
52+
i64(42)
53+
54+
expr1
55+
56+
expr2
57+
58+
# We can use the `check` commands to check properties of our e-graph.
59+
60+
x, y = vars_("x y", Num)
61+
assert egraph.check_bool(expr1 == x * y)
62+
63+
# This checks if `expr1` is equivalent to some expression `x * y`, where `x` and `y` are
64+
# variables that can be mapped to any `Num` expression in the e-graph.
65+
#
66+
# Checks can fail. For example the following check fails because `expr1` is not equivalent to
67+
# `x + y` for any `x` and `y` in the e-graph.
68+
69+
assert not egraph.check_bool(expr1 == x + y)
70+
71+
# Let us define some rewrite rules over our small DSL.
72+
73+
egraph.register(rewrite(x + y).to(y + x))
74+
75+
# This rule asserts that addition is commutative. More concretely, this rule says that if the e-graph
76+
# contains expressions of the form `x + y`, then the e-graph should also contain the
77+
# expression `y + x`, and they should be equivalent.
78+
#
79+
# Similarly, we can define the associativity rule for addition.
80+
81+
z = var("z", Num)
82+
egraph.register(rewrite(x + (y + z)).to((x + y) + z))
83+
84+
# This rule says, if the e-graph contains expressions of the form `x + (y + z)`, then the e-graph should also contain
85+
# the expression `(x + y) + z`, and they should be equivalent.
86+
87+
# There are two subtleties to rules:
88+
#
89+
# 1. Defining a rule is different from running it. The following check would fail at this point
90+
# because the commutativity rule has not been run (we've inserted `x + 3` but not yet derived `3 + x`).
91+
92+
assert not egraph.check_bool((x + 3) == (3 + x))
93+
94+
# 2. Rules are not instantiated for every possible term; they are only instantiated for terms that are
95+
# in the e-graph. For instance, even if we ran the commutativity rule above, the following check would
96+
# still fail because the e-graph does not contain either of the terms `Num(-2) + Num(2)` or `Num(2) + Num(-2)`.
97+
98+
assert not egraph.check_bool(Num(-2) + 2 == Num(2) + -2)
99+
100+
# Let's also define commutativity and associativity for multiplication.
101+
102+
egraph.register(
103+
rewrite(x * y).to(y * x),
104+
rewrite(x * (y * z)).to((x * y) * z),
105+
)
106+
107+
# `egglog` also defines a set of built-in functions over primitive types, such as `+` and `*`,
108+
# and supports operator overloading, so the same operator can be used with different types.
109+
110+
egraph.extract(i64(1) + 2)
111+
112+
egraph.extract(String("1") + "2")
113+
114+
egraph.extract(f64(1.0) + 2.0)
115+
116+
# With primitives, we can define rewrite rules that talk about the semantics of operators.
117+
# The following rules show constant folding over addition and multiplication.
118+
119+
a, b = vars_("a b", i64)
120+
egraph.register(
121+
rewrite(Num(a) + Num(b)).to(Num(a + b)),
122+
rewrite(Num(a) * Num(b)).to(Num(a * b)),
123+
)
124+
125+
# While we have defined several rules, the e-graph has not changed since we inserted the two
126+
# expressions. To run rules we have defined so far, we can use `run`.
127+
128+
egraph.run(10)
129+
130+
# This tells `egglog` to run our rules for 10 iterations. More precisely, egglog runs the
131+
# following pseudo code:
132+
#
133+
# ```
134+
# for i in range(10):
135+
# for r in rules:
136+
# ms = r.find_matches(egraph)
137+
# for m in ms:
138+
# egraph = egraph.apply_rule(r, m)
139+
# egraph = rebuild(egraph)
140+
# ```
141+
#
142+
# In other words, `egglog` computes all the matches for one iteration before making any
143+
# updates to the e-graph. This is in contrast to an evaluation model where rules are immediately
144+
# applied and the matches are obtained on demand over a changing e-graph.
145+
#
146+
# We can now look at the e-graph and see that `2 * (x * 3)` and `6 * x` are now in the same E-class.
147+
148+
egraph
149+
150+
# We can also check this fact explicitly
151+
152+
egraph.check(expr1 == expr2)

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ docs = [
7070
"line-profiler",
7171
"sphinxcontrib-mermaid",
7272
"ablog",
73+
"jupytext",
7374
]
7475

7576

@@ -222,6 +223,7 @@ preview = true
222223
[tool.ruff.lint.per-file-ignores]
223224
# Don't require annotations for tests
224225
"python/tests/**" = ["ANN001", "ANN201", "INP001"]
226+
"docs/**" = ["I001", "PLW0131"]
225227

226228
# Disable these tests instead for now since ruff doesn't support including all method annotations of decorated class
227229
# [tool.ruff.lint.flake8-type-checking]

uv.lock

Lines changed: 21 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)