Skip to content

Commit 5c73999

Browse files
Add working loopnest example
1 parent 7848e74 commit 5c73999

File tree

3 files changed

+42
-76
lines changed

3 files changed

+42
-76
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ _This project uses semantic versioning_
99
- Add better error message when using @function in class (thanks @shinawy)
1010
- Add error method if `@method` decorator is in wrong place
1111
- Subsumes lambda functions after replacing
12+
- Add working loopnest test
1213

1314
## 8.0.1 (2024-10-24)
1415

python/egglog/exp/array_api_loopnest.py

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from egglog import *
1313
from egglog.exp.array_api import *
1414

15+
__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]
16+
1517

1618
class ShapeAPI(Expr):
1719
def __init__(self, dims: TupleIntLike) -> None: ...
@@ -105,76 +107,3 @@ def _loopnest_api_ruleset(
105107
yield rewrite(lna.indices, subsume=True).to(
106108
tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range))
107109
)
108-
109-
110-
@function(ruleset=array_api_ruleset, subsume=True)
111-
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
112-
# peel off the outer shape for result array
113-
outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
114-
# get only the inner shape for reduction
115-
reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
116-
117-
return NDArray(
118-
outshape,
119-
X.dtype,
120-
lambda k: sqrt(
121-
LoopNestAPI.from_tuple(reduce_axis)
122-
.unwrap()
123-
.fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
124-
).to_value(),
125-
)
126-
127-
128-
# %%
129-
# egraph = EGraph(save_egglog_string=True)
130-
131-
# egraph.register(val.shape)
132-
# egraph.run(array_api_ruleset.saturate())
133-
# egraph.extract_multiple(val.shape, 10)
134-
135-
# %%
136-
137-
X = NDArray.var("X")
138-
assume_shape(X, (3, 2, 3, 4))
139-
val = linalg_norm(X, (0, 1))
140-
egraph = EGraph()
141-
x = egraph.let("x", val.shape[2])
142-
# egraph.display(n_inline_leaves=0)
143-
# egraph.extract(x)
144-
# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0)
145-
# egraph.run(array_api_ruleset.saturate())
146-
# egraph.extract(x)
147-
# egraph.display()
148-
149-
150-
# %%
151-
152-
# x = xs[-2]
153-
# # %%
154-
# decls = x.__egg_decls__
155-
# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])
156-
157-
# # %%
158-
# # x.__egg_typed_expr__.expr.args[1].expr.args[1] # %%
159-
160-
# # %%
161-
# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]))
162-
163-
164-
# from egglog import pretty
165-
166-
# decl = (
167-
# x.__egg_typed_expr__.expr.args[1]
168-
# .expr.args[2]
169-
# .expr.args[0]
170-
# .expr.args[1]
171-
# .expr.call.args[0]
172-
# .expr.call.args[0]
173-
# .expr.call.args[0]
174-
# )
175-
176-
# # pprint.pprint(decl)
177-
178-
# print(pretty.pretty_decl(decls, decl.expr))
179-
180-
# # %%

python/tests/test_array_api.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from sklearn import config_context, datasets
99
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
1010

11+
from egglog.egraph import set_current_ruleset
1112
from egglog.exp.array_api import *
1213
from egglog.exp.array_api_jit import jit
14+
from egglog.exp.array_api_loopnest import *
1315
from egglog.exp.array_api_numba import array_api_numba_schedule
1416
from egglog.exp.array_api_program_gen import *
1517

@@ -68,6 +70,41 @@ def test_reshape_vec_noop():
6870
egraph.check(eq(res).to(x))
6971

7072

73+
def test_filter():
74+
with set_current_ruleset(array_api_ruleset):
75+
x = TupleInt.range(5).filter(lambda i: i < 2).length()
76+
check_eq(x, Int(2), array_api_schedule)
77+
78+
79+
@function(ruleset=array_api_ruleset, subsume=True)
80+
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
81+
# peel off the outer shape for result array
82+
outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
83+
# get only the inner shape for reduction
84+
reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
85+
86+
return NDArray(
87+
outshape,
88+
X.dtype,
89+
lambda k: sqrt(
90+
LoopNestAPI.from_tuple(reduce_axis)
91+
.unwrap()
92+
.fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
93+
).to_value(),
94+
)
95+
96+
97+
class TestLoopNest:
98+
def test_shape(self):
99+
X = NDArray.var("X")
100+
assume_shape(X, (3, 2, 3, 4))
101+
val = linalg_norm(X, (0, 1))
102+
103+
check_eq(val.shape.length(), Int(2), array_api_schedule)
104+
check_eq(val.shape[0], Int(3), array_api_schedule)
105+
check_eq(val.shape[1], Int(4), array_api_schedule)
106+
107+
71108
# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
72109
# The next step will load the old snapshot and run their test on it.
73110

@@ -80,7 +117,6 @@ def run_lda(x, y):
80117

81118
iris = datasets.load_iris()
82119
X_np, y_np = (iris.data, iris.target)
83-
res_np = run_lda(X_np, y_np)
84120

85121

86122
def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
@@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
165201
optimized_expr = simplify_lda(egraph, expr)
166202
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
167203
py_object = benchmark(load_source, fn_program, egraph)
168-
assert np.allclose(py_object(X_np, y_np), res_np)
204+
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
169205
assert egraph.eval(fn_program.statements) == snapshot_py
170206

171207
@pytest.mark.parametrize(
@@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
180216
)
181217
def test_execution(self, fn, benchmark):
182218
# warmup once for numba
183-
assert np.allclose(res_np, fn(X_np, y_np))
219+
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))
184220
benchmark(fn, X_np, y_np)
185221

186222

0 commit comments

Comments
 (0)