Skip to content

Commit 984daca

Browse files
Add egglog tutorials
1 parent 676812b commit 984daca

File tree

7 files changed

+941
-32
lines changed

7 files changed

+941
-32
lines changed

docs/tutorials.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,8 @@ auto_examples/index
55
tutorials/getting-started
66
tutorials/sklearn
77
tutorials/tut_1_basics
8+
tutorials/tut_2_datalog
9+
tutorials/tut_3_analysis
10+
tutorials/tut_4_scheduling
11+
tutorials/tut_5_extraction
812
```

docs/tutorials/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""
2+
Tutorials
3+
"""

docs/tutorials/tut_1_basics.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# mypy: disable-error-code="empty-body"
1111
from __future__ import annotations
1212
from typing import TypeAlias
13+
from collections.abc import Iterable
1314
from egglog import *
1415

1516

@@ -19,67 +20,94 @@ def __init__(self, value: i64Like) -> None: ...
1920
@classmethod
2021
def var(cls, name: StringLike) -> Num: ...
2122

22-
def __add__(self, other: ExprLike) -> Num: ...
23-
def __mul__(self, other: ExprLike) -> Num: ...
23+
def __add__(self, other: NumLike) -> Num: ...
24+
def __mul__(self, other: NumLike) -> Num: ...
2425

2526
# Support inverse operations for convenience
2627
# they will be translated to non-reversed ones
27-
def __radd__(self, other: ExprLike) -> Num: ...
28-
def __rmul__(self, other: ExprLike) -> Num: ...
28+
def __radd__(self, other: NumLike) -> Num: ...
29+
def __rmul__(self, other: NumLike) -> Num: ...
2930

3031

31-
ExprLike: TypeAlias = Num | StringLike | i64Like
32-
converter(i64, Num, Num)
33-
converter(str, Num, Num.var)
32+
NumLike: TypeAlias = Num | StringLike | i64Like
3433
# -
3534

35+
36+
# The signature here takes `NumLike` not `Num` so that you can write `Num(1) + 2` instead of
37+
# `Num(1) + Num(2)`. This is helpful for ease of use and also for compatibility when you are trying to
38+
# create expressions that act like Python objects which perform upcasting.
39+
#
40+
# To support this, you must define conversions between primitive types and your expression types.
41+
# When a value is passed into a function, it will find the type it should be converted to and
42+
# transitively apply the conversions you have defined:
43+
44+
converter(i64, Num, Num)
45+
converter(String, Num, Num.var)
46+
3647
# Now, let's define some simple expressions.
3748

49+
egraph = EGraph()
3850
x = Num.var("x")
39-
expr1 = 2 * (x * 3)
40-
expr2 = 6 * x
51+
expr1 = egraph.let("expr1", 2 * (x * 3))
52+
expr2 = egraph.let("expr2", 6 * x)
4153

4254
# You should see an e-graph with two expressions.
4355

44-
egraph = EGraph()
45-
egraph.register(expr1, expr2)
4656
egraph
4757

48-
# We can print the values of the expressions as well to see their fully expanded forms.
58+
# We can `.extract` the values of the expressions as well to see their fully expanded forms.
4959

50-
String("Hello, world!")
60+
egraph.extract(String("Hello, world!"))
5161

52-
i64(42)
62+
egraph.extract(i64(42))
5363

54-
expr1
64+
egraph.extract(expr1)
5565

56-
expr2
66+
egraph.extract(expr2)
5767

5868
# We can use the `check` commands to check properties of our e-graph.
5969

6070
x, y = vars_("x y", Num)
61-
assert egraph.check_bool(expr1 == x * y)
71+
egraph.check(expr1 == x * y)
6272

6373
# This checks if `expr1` is equivalent to some expression `x * y`, where `x` and `y` are
6474
# variables that can be mapped to any `Num` expression in the e-graph.
6575
#
6676
# Checks can fail. For example the following check fails because `expr1` is not equivalent to
6777
# `x + y` for any `x` and `y` in the e-graph.
6878

69-
assert not egraph.check_bool(expr1 == x + y)
79+
egraph.check_fail(expr1 == x + y)
7080

7181
# Let us define some rewrite rules over our small DSL.
7282

73-
egraph.register(rewrite(x + y).to(y + x))
83+
84+
@egraph.register
85+
def _add_comm(x: Num, y: Num):
86+
yield rewrite(x + y).to(y + x)
87+
88+
89+
# This could also been written like:
90+
#
91+
# ```python
92+
# x, y = vars_("x y", Num)
93+
# egraph.register(rewrite(x + y).to(y + x))
94+
# ```
95+
#
96+
# In this tutorial we will use the function form to define rewrites and rules, because then then we only
97+
# have to write the variable names once as arguments and they are not leaked to the outer scope.
98+
7499

75100
# This rule asserts that addition is commutative. More concretely, this rules says, if the e-graph
76101
# contains expressions of the form `x + y`, then the e-graph should also contain the
77102
# expression `y + x`, and they should be equivalent.
78103
#
79104
# Similarly, we can define the associativity rule for addition.
80105

81-
z = var("z", Num)
82-
egraph.register(rewrite(x + (y + z)).to((x + y) + z))
106+
107+
@egraph.register
108+
def _add_assoc(x: Num, y: Num, z: Num) -> Iterable[RewriteOrRule]:
109+
yield rewrite(x + (y + z)).to((x + y) + z)
110+
83111

84112
# This rule says, if the e-graph contains expressions of the form `x + (y + z)`, then the e-graph should also contain
85113
# the expression `(x + y) + z`, and they should be equivalent.
@@ -89,20 +117,22 @@ def __rmul__(self, other: ExprLike) -> Num: ...
89117
# 1. Defining a rule is different from running it. The following check would fail at this point
90118
# because the commutativity rule has not been run (we've inserted `x + 3` but not yet derived `3 + x`).
91119

92-
assert not egraph.check_bool((x + 3) == (3 + x))
120+
egraph.check_fail((x + 3) == (3 + x))
93121

94122
# 2. Rules are not instantiated for every possible term; they are only instantiated for terms that are
95123
# in the e-graph. For instance, even if we ran the commutativity rule above, the following check would
96124
# still fail because the e-graph does not contain either of the terms `Num(-2) + Num(2)` or `Num(2) + Num(-2)`.
97125

98-
assert not egraph.check_bool(Num(-2) + 2 == Num(2) + -2)
126+
egraph.check_fail(Num(-2) + 2 == Num(2) + -2)
99127

100128
# Let's also define commutativity and associativity for multiplication.
101129

102-
egraph.register(
103-
rewrite(x * y).to(y * x),
104-
rewrite(x * (y * z)).to((x * y) * z),
105-
)
130+
131+
@egraph.register
132+
def _mul(x: Num, y: Num, z: Num) -> Iterable[RewriteOrRule]:
133+
yield rewrite(x * y).to(y * x)
134+
yield rewrite(x * (y * z)).to((x * y) * z)
135+
106136

107137
# `egglog` also defines a set of built-in functions over primitive types, such as `+` and `*`,
108138
# and supports operator overloading, so the same operator can be used with different types.
@@ -116,11 +146,12 @@ def __rmul__(self, other: ExprLike) -> Num: ...
116146
# With primitives, we can define rewrite rules that talk about the semantics of operators.
117147
# The following rules show constant folding over addition and multiplication.
118148

119-
a, b = vars_("a b", i64)
120-
egraph.register(
121-
rewrite(Num(a) + Num(b)).to(Num(a + b)),
122-
rewrite(Num(a) * Num(b)).to(Num(a * b)),
123-
)
149+
150+
@egraph.register
151+
def _const_fold(a: i64, b: i64) -> Iterable[RewriteOrRule]:
152+
yield rewrite(Num(a) + Num(b)).to(Num(a + b))
153+
yield rewrite(Num(a) * Num(b)).to(Num(a * b))
154+
124155

125156
# While we have defined several rules, the e-graph has not changed since we inserted the two
126157
# expressions. To run rules we have defined so far, we can use `run`.

0 commit comments

Comments
 (0)