1010# mypy: disable-error-code="empty-body"
1111from __future__ import annotations
1212from typing import TypeAlias
13+ from collections .abc import Iterable
1314from egglog import *
1415
1516
@@ -19,67 +20,94 @@ def __init__(self, value: i64Like) -> None: ...
1920 @classmethod
2021 def var (cls , name : StringLike ) -> Num : ...
2122
22- def __add__ (self , other : ExprLike ) -> Num : ...
23- def __mul__ (self , other : ExprLike ) -> Num : ...
23+ def __add__ (self , other : NumLike ) -> Num : ...
24+ def __mul__ (self , other : NumLike ) -> Num : ...
2425
2526 # Support inverse operations for convenience
2627 # they will be translated to non-reversed ones
27- def __radd__ (self , other : ExprLike ) -> Num : ...
28- def __rmul__ (self , other : ExprLike ) -> Num : ...
28+ def __radd__ (self , other : NumLike ) -> Num : ...
29+ def __rmul__ (self , other : NumLike ) -> Num : ...
2930
3031
31- ExprLike : TypeAlias = Num | StringLike | i64Like
32- converter (i64 , Num , Num )
33- converter (str , Num , Num .var )
32+ NumLike : TypeAlias = Num | StringLike | i64Like
3433# -
3534
35+
36+ # The signature here takes `NumLike` not `Num` so that you can write `Num(1) + 2` instead of
37+ # `Num(1) + Num(2)`. This is helpful for ease of use and also for compatibility when you are trying to
38+ # create expressions that act like Python objects which perform upcasting.
39+ #
40+ # To support this, you must define conversions between primitive types and your expression types.
41+ # When a value is passed into a function, it will find the type it should be converted to and
42+ # transitively apply the conversions you have defined:
43+
44+ converter (i64 , Num , Num )
45+ converter (String , Num , Num .var )
46+
3647# Now, let's define some simple expressions.
3748
49+ egraph = EGraph ()
3850x = Num .var ("x" )
39- expr1 = 2 * (x * 3 )
40- expr2 = 6 * x
51+ expr1 = egraph . let ( "expr1" , 2 * (x * 3 ) )
52+ expr2 = egraph . let ( "expr2" , 6 * x )
4153
4254# You should see an e-graph with two expressions.
4355
44- egraph = EGraph ()
45- egraph .register (expr1 , expr2 )
4656egraph
4757
48- # We can print the values of the expressions as well to see their fully expanded forms.
58+ # We can `.extract` the values of the expressions as well to see their fully expanded forms.
4959
50- String ("Hello, world!" )
60+ egraph . extract ( String ("Hello, world!" ) )
5161
52- i64 (42 )
62+ egraph . extract ( i64 (42 ) )
5363
54- expr1
64+ egraph . extract ( expr1 )
5565
56- expr2
66+ egraph . extract ( expr2 )
5767
5868# We can use the `check` commands to check properties of our e-graph.
5969
6070x , y = vars_ ("x y" , Num )
61- assert egraph .check_bool (expr1 == x * y )
71+ egraph .check (expr1 == x * y )
6272
6373# This checks if `expr1` is equivalent to some expression `x * y`, where `x` and `y` are
6474# variables that can be mapped to any `Num` expression in the e-graph.
6575#
6676# Checks can fail. For example the following check fails because `expr1` is not equivalent to
6777# `x + y` for any `x` and `y` in the e-graph.
6878
69- assert not egraph .check_bool (expr1 == x + y )
79+ egraph .check_fail (expr1 == x + y )
7080
7181# Let us define some rewrite rules over our small DSL.
7282
73- egraph .register (rewrite (x + y ).to (y + x ))
83+
84+ @egraph .register
85+ def _add_comm (x : Num , y : Num ):
86+ yield rewrite (x + y ).to (y + x )
87+
88+
89+ # This could also been written like:
90+ #
91+ # ```python
92+ # x, y = vars_("x y", Num)
93+ # egraph.register(rewrite(x + y).to(y + x))
94+ # ```
95+ #
96+ # In this tutorial we will use the function form to define rewrites and rules, because then then we only
97+ # have to write the variable names once as arguments and they are not leaked to the outer scope.
98+
7499
75100# This rule asserts that addition is commutative. More concretely, this rules says, if the e-graph
76101# contains expressions of the form `x + y`, then the e-graph should also contain the
77102# expression `y + x`, and they should be equivalent.
78103#
79104# Similarly, we can define the associativity rule for addition.
80105
81- z = var ("z" , Num )
82- egraph .register (rewrite (x + (y + z )).to ((x + y ) + z ))
106+
107+ @egraph .register
108+ def _add_assoc (x : Num , y : Num , z : Num ) -> Iterable [RewriteOrRule ]:
109+ yield rewrite (x + (y + z )).to ((x + y ) + z )
110+
83111
84112# This rule says, if the e-graph contains expressions of the form `x + (y + z)`, then the e-graph should also contain
85113# the expression `(x + y) + z`, and they should be equivalent.
@@ -89,20 +117,22 @@ def __rmul__(self, other: ExprLike) -> Num: ...
89117# 1. Defining a rule is different from running it. The following check would fail at this point
90118# because the commutativity rule has not been run (we've inserted `x + 3` but not yet derived `3 + x`).
91119
92- assert not egraph .check_bool ((x + 3 ) == (3 + x ))
120+ egraph .check_fail ((x + 3 ) == (3 + x ))
93121
94122# 2. Rules are not instantiated for every possible term; they are only instantiated for terms that are
95123# in the e-graph. For instance, even if we ran the commutativity rule above, the following check would
96124# still fail because the e-graph does not contain either of the terms `Num(-2) + Num(2)` or `Num(2) + Num(-2)`.
97125
98- assert not egraph .check_bool (Num (- 2 ) + 2 == Num (2 ) + - 2 )
126+ egraph .check_fail (Num (- 2 ) + 2 == Num (2 ) + - 2 )
99127
100128# Let's also define commutativity and associativity for multiplication.
101129
102- egraph .register (
103- rewrite (x * y ).to (y * x ),
104- rewrite (x * (y * z )).to ((x * y ) * z ),
105- )
130+
131+ @egraph .register
132+ def _mul (x : Num , y : Num , z : Num ) -> Iterable [RewriteOrRule ]:
133+ yield rewrite (x * y ).to (y * x )
134+ yield rewrite (x * (y * z )).to ((x * y ) * z )
135+
106136
107137# `egglog` also defines a set of built-in functions over primitive types, such as `+` and `*`,
108138# and supports operator overloading, so the same operator can be used with different types.
@@ -116,11 +146,12 @@ def __rmul__(self, other: ExprLike) -> Num: ...
116146# With primitives, we can define rewrite rules that talk about the semantics of operators.
117147# The following rules show constant folding over addition and multiplication.
118148
119- a , b = vars_ ("a b" , i64 )
120- egraph .register (
121- rewrite (Num (a ) + Num (b )).to (Num (a + b )),
122- rewrite (Num (a ) * Num (b )).to (Num (a * b )),
123- )
149+
150+ @egraph .register
151+ def _const_fold (a : i64 , b : i64 ) -> Iterable [RewriteOrRule ]:
152+ yield rewrite (Num (a ) + Num (b )).to (Num (a + b ))
153+ yield rewrite (Num (a ) * Num (b )).to (Num (a * b ))
154+
124155
125156# While we have defined several rules, the e-graph has not changed since we inserted the two
126157# expressions. To run rules we have defined so far, we can use `run`.
0 commit comments