Skip to content

Commit 99d3eca

Browse files
Merge pull request #226 from egraphs-good/generate-egg
Be able to generate egg from python example
2 parents f1190bf + e492206 commit 99d3eca

File tree

8 files changed

+114
-68
lines changed

8 files changed

+114
-68
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ jobs:
6969
- uses: CodSpeedHQ/action@v3
7070
with:
7171
token: ${{ secrets.CODSPEED_TOKEN }}
72-
run: uv run pytest -vvv -n auto
72+
# allow updating snapshots because the benchmarks are nondeterministic
73+
run: uv run pytest -vvv -n auto --snapshot-update
7374

7475
docs:
7576
runs-on: ubuntu-latest

docs/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ _This project uses semantic versioning_
55
## UNRELEASED
66

77
- Upgrade dependencies including [egglog](https://github.com/egraphs-good/egglog/compare/saulshanabrook:egg-smol:a555b2f5e82c684442775cc1a5da94b71930113c...b0db06832264c9b22694bd3de2bdacd55bbe9e32)
8+
- Fix bug with non glob star import
9+
- Fix bug when extracting functions
810

911
## 8.0.0 (2024-10-17)
1012

python/egglog/egraph_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
493493
if term.name == "py-object":
494494
call = bindings.termdag_term_to_expr(self.termdag, term)
495495
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
496-
if term.name == "unstable-fn":
496+
elif term.name == "unstable-fn":
497497
# Get function name
498498
fn_term, *arg_terms = term.args
499499
fn_value = self.resolve_term(fn_term, JustTypeRef("String"))

python/egglog/exp/array_api_jit.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,17 @@ def jit(fn: X) -> X:
1717
# 1. Create variables for each of the two args in the functions
1818
sig = inspect.signature(fn)
1919
arg1, arg2 = sig.parameters.keys()
20-
21-
with EGraph() as egraph:
20+
egraph = EGraph()
21+
with egraph:
2222
res = fn(NDArray.var(arg1), NDArray.var(arg2))
2323
egraph.register(res)
2424
egraph.run(array_api_numba_schedule)
2525
res_optimized = egraph.extract(res)
26-
egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
26+
# egraph.display(split_primitive_outputs=True, n_inline_leaves=3)
2727

28-
egraph = EGraph()
2928
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
3029
egraph.register(fn_program)
3130
egraph.run(array_api_program_gen_schedule)
32-
fn = cast(X, egraph.eval(fn_program.py_object))
31+
fn = cast(X, egraph.eval(egraph.extract(fn_program.py_object)))
3332
fn.expr = res_optimized # type: ignore[attr-defined]
3433
return fn

python/egglog/exp/array_api_program_gen.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# mypy: disable-error-code="empty-body"
22
from __future__ import annotations
33

4-
import numpy as np
5-
64
from egglog import *
75

86
from .array_api import *
@@ -13,9 +11,12 @@
1311
# Depends on `np` as a global variable.
1412
##
1513

16-
array_api_program_gen_ruleset = ruleset()
14+
array_api_program_gen_ruleset = ruleset(name="array_api_program_gen_ruleset")
15+
array_api_program_gen_eval_ruleset = ruleset(name="array_api_program_gen_eval_ruleset")
1716

18-
array_api_program_gen_schedule = array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate()
17+
array_api_program_gen_schedule = (
18+
array_api_program_gen_ruleset | program_gen_ruleset | array_api_program_gen_eval_ruleset | eval_program_rulseset
19+
).saturate()
1920

2021

2122
@function
@@ -98,17 +99,14 @@ def _tuple_int_program(i: Int, ti: TupleInt, k: i64, idx_fn: Callable[[Int], Int
9899
def ndarray_program(x: NDArray) -> Program: ...
99100

100101

101-
@function
102-
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> Program: ...
102+
@function(ruleset=array_api_program_gen_ruleset)
103+
def ndarray_function_two_program(res: NDArray, l: NDArray, r: NDArray) -> Program:
104+
return ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))
103105

104106

105-
@array_api_program_gen_ruleset.register
106-
def _ndarray_function_two(f: Program, res: NDArray, l: NDArray, r: NDArray, o: PyObject):
107-
# When we have function, set the program and trigger it to be compiled
108-
yield rule(eq(f).to(ndarray_function_two(res, l, r))).then(
109-
union(f).with_(ndarray_program(res).function_two(ndarray_program(l), ndarray_program(r))),
110-
f.eval_py_object({"np": np}),
111-
)
107+
@function(ruleset=array_api_program_gen_eval_ruleset)
108+
def ndarray_function_two(res: NDArray, l: NDArray, r: NDArray) -> EvalProgram:
109+
return EvalProgram(ndarray_function_two_program(res, l, r), {"np": np})
112110

113111

114112
@function

python/egglog/exp/program_gen.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,18 @@ def parent(self) -> Program:
8383
Only keeps the original parent, not any additional ones, so that each set of statements is only added once.
8484
"""
8585

86-
@method(default=Unit())
87-
def eval_py_object(self, globals: object) -> Unit:
86+
@property
87+
def is_identifer(self) -> Bool:
88+
"""
89+
Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
90+
"""
91+
92+
93+
converter(String, Program, Program)
94+
95+
96+
class EvalProgram(Expr):
97+
def __init__(self, program: Program, globals: object) -> None:
8898
"""
8999
Evaluates the program and saves as the py_object
90100
"""
@@ -98,38 +108,34 @@ def py_object(self) -> PyObject:
98108
"""
99109

100110
@property
101-
def is_identifer(self) -> Bool:
111+
def statements(self) -> String:
102112
"""
103-
Returns whether the expression is an identifier. Used so that we don't re-assign any identifiers.
113+
Returns the statements of the program, if it's been compiled
104114
"""
105115

106116

107-
converter(String, Program, Program)
108-
109-
program_gen_ruleset = ruleset()
110-
111-
112-
@program_gen_ruleset.register
113-
def _py_object(p: Program, expr: String, statements: String, g: PyObject):
117+
@ruleset
118+
def eval_program_rulseset(ep: EvalProgram, p: Program, expr: String, statements: String, g: PyObject):
114119
# When we evaluate a program, we first want to compile to a string
115-
yield rule(p.eval_py_object(g)).then(p.compile())
120+
yield rule(EvalProgram(p, g)).then(p.compile())
116121
# Then we want to evaluate the statements/expr
117122
yield rule(
118-
p.eval_py_object(g),
123+
eq(ep).to(EvalProgram(p, g)),
119124
eq(p.statements).to(statements),
120125
eq(p.expr).to(expr),
121126
).then(
122-
set_(p.py_object).to(
127+
set_(ep.py_object).to(
123128
py_eval(
124129
"l['___res']",
125130
PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)),
126131
)
127-
)
132+
),
133+
set_(ep.statements).to(statements),
128134
)
129135

130136

131-
@program_gen_ruleset.register
132-
def _compile(
137+
@ruleset
138+
def program_gen_ruleset(
133139
s: String,
134140
s1: String,
135141
s2: String,

python/tests/test_array_api.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
1010

1111
from egglog.exp.array_api import *
12+
from egglog.exp.array_api_jit import jit
1213
from egglog.exp.array_api_numba import array_api_numba_schedule
1314
from egglog.exp.array_api_program_gen import *
1415

@@ -103,51 +104,69 @@ def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
103104
return globals[var]
104105

105106

106-
def load_source(expr, egraph: EGraph):
107-
with egraph:
108-
fn_program = egraph.let("fn_program", ndarray_function_two(expr, NDArray.var("X"), NDArray.var("y")))
109-
egraph.run(array_api_program_gen_schedule)
110-
return egraph.eval(egraph.extract(fn_program.statements))
107+
def load_source(fn_program: EvalProgram, egraph: EGraph):
108+
egraph.register(fn_program)
109+
egraph.run(array_api_program_gen_schedule)
110+
# do the needed pieces in here for benchmarking
111+
return egraph.eval(egraph.extract(fn_program.py_object))
111112

112113

113-
def trace_lda(egraph: EGraph):
114-
X_arr = NDArray.var("X")
115-
assume_dtype(X_arr, X_np.dtype)
116-
assume_shape(X_arr, X_np.shape)
117-
assume_isfinite(X_arr)
114+
def lda(X, y):
115+
assume_dtype(X, X_np.dtype)
116+
assume_shape(X, X_np.shape)
117+
assume_isfinite(X)
118118

119-
y_arr = NDArray.var("y")
120-
assume_dtype(y_arr, y_np.dtype)
121-
assume_shape(y_arr, y_np.shape)
122-
assume_value_one_of(y_arr, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]
119+
assume_dtype(y, y_np.dtype)
120+
assume_shape(y, y_np.shape)
121+
assume_value_one_of(y, tuple(map(int, np.unique(y_np)))) # type: ignore[arg-type]
122+
return run_lda(X, y)
123123

124-
with egraph:
125-
return run_lda(X_arr, y_arr)
124+
125+
def simplify_lda(egraph: EGraph, expr: NDArray) -> NDArray:
126+
egraph.register(expr)
127+
egraph.run(array_api_numba_schedule)
128+
return egraph.extract(expr)
126129

127130

128131
@pytest.mark.benchmark(min_rounds=3)
129132
class TestLDA:
133+
"""
134+
Incrementally benchmark each part of the LDA to see how long it takes to run.
135+
"""
136+
130137
def test_trace(self, snapshot_py, benchmark):
131-
X_r2 = benchmark(trace_lda, EGraph())
138+
X = NDArray.var("X")
139+
y = NDArray.var("y")
140+
with EGraph():
141+
X_r2 = benchmark(lda, X, y)
132142
assert str(X_r2) == snapshot_py
133143

134144
def test_optimize(self, snapshot_py, benchmark):
135145
egraph = EGraph()
136-
expr = trace_lda(egraph)
137-
simplified = benchmark(egraph.simplify, expr, array_api_numba_schedule)
146+
X = NDArray.var("X")
147+
y = NDArray.var("y")
148+
with egraph:
149+
expr = lda(X, y)
150+
simplified = benchmark(simplify_lda, egraph, expr)
138151
assert str(simplified) == snapshot_py
139152

140-
@pytest.mark.xfail(reason="Original source is not working")
141-
def test_source(self, snapshot_py, benchmark):
142-
egraph = EGraph()
143-
expr = trace_lda(egraph)
144-
assert benchmark(load_source, expr, egraph) == snapshot_py
153+
# @pytest.mark.xfail(reason="Original source is not working")
154+
# def test_source(self, snapshot_py, benchmark):
155+
# egraph = EGraph()
156+
# expr = trace_lda(egraph)
157+
# assert benchmark(load_source, expr, egraph) == snapshot_py
145158

146159
def test_source_optimized(self, snapshot_py, benchmark):
147160
egraph = EGraph()
148-
expr = trace_lda(egraph)
149-
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
150-
assert benchmark(load_source, optimized_expr, egraph) == snapshot_py
161+
X = NDArray.var("X")
162+
y = NDArray.var("y")
163+
with egraph:
164+
expr = lda(X, y)
165+
optimized_expr = simplify_lda(egraph, expr)
166+
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
167+
py_object = benchmark(load_source, fn_program, egraph)
168+
assert np.allclose(py_object(X_np, y_np), res_np)
169+
assert egraph.eval(fn_program.statements) == snapshot_py
151170

152171
@pytest.mark.parametrize(
153172
"fn",
@@ -156,9 +175,29 @@ def test_source_optimized(self, snapshot_py, benchmark):
156175
pytest.param(run_lda, id="array_api"),
157176
pytest.param(_load_py_snapshot(test_source_optimized, "__fn"), id="array_api-optimized"),
158177
pytest.param(numba.njit(_load_py_snapshot(test_source_optimized, "__fn")), id="array_api-optimized-numba"),
178+
pytest.param(jit(lda), id="array_api-jit"),
159179
],
160180
)
161181
def test_execution(self, fn, benchmark):
162182
# warmup once for numba
163183
assert np.allclose(res_np, fn(X_np, y_np))
164184
benchmark(fn, X_np, y_np)
185+
186+
187+
# if calling as script, print out egglog source for test
188+
# similar to jit, but don't include pyobject parts so it works in vanilla egglog
189+
if __name__ == "__main__":
190+
print("Generating egglog source for test")
191+
egraph = EGraph(save_egglog_string=True)
192+
X_ = NDArray.var("X")
193+
y_ = NDArray.var("y")
194+
with egraph:
195+
expr = lda(X_, y_)
196+
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
197+
fn_program = ndarray_function_two_program(optimized_expr, X_, y_)
198+
egraph.register(fn_program.compile())
199+
egraph.run(array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate())
200+
egraph.extract(fn_program.statements)
201+
name = "python.egg"
202+
print("Saving to", name)
203+
Path(name).write_text(egraph.as_egglog_string)

python/tests/test_program_gen.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_to_string(snapshot_py) -> None:
5555
egraph = EGraph()
5656
egraph.register(fn)
5757
egraph.register(fn.compile())
58-
egraph.run(to_program_ruleset * 100 + program_gen_ruleset * 200)
58+
egraph.run((to_program_ruleset | program_gen_ruleset).saturate())
5959
# egraph.display(n_inline_leaves=1)
6060
assert egraph.eval(fn.expr) == "my_fn"
6161
assert egraph.eval(fn.statements) == snapshot_py
@@ -67,8 +67,9 @@ def test_py_object():
6767
z = Math.var("z")
6868
fn = (x + y + z).program.function_two(x.program, y.program)
6969
egraph = EGraph()
70-
egraph.register(fn.eval_py_object({"z": 10}))
71-
egraph.run(to_program_ruleset * 100 + program_gen_ruleset * 100)
72-
res = egraph.eval(fn.py_object)
70+
evalled = EvalProgram(fn, {"z": 10})
71+
egraph.register(evalled)
72+
egraph.run((to_program_ruleset | eval_program_rulseset | program_gen_ruleset).saturate())
73+
res = egraph.eval(evalled.py_object)
7374
assert res(1, 2) == 13 # type: ignore[operator]
7475
assert inspect.getsource(res) # type: ignore[arg-type]

0 commit comments

Comments
 (0)