
Commit 8b17580

Merge pull request #242 from egraphs-good/fix-loopnest
Working loopnest example
2 parents: 9bb83d4 + 5c73999

File tree: 7 files changed (+64 lines, -82 lines)

docs/changelog.md

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ _This project uses semantic versioning_
 - Fix pretty printing of lambda functions
 - Add support for subsuming rewrite generated by default function and method definitions
 - Add better error message when using @function in class (thanks @shinawy)
+- Add error method if `@method` decorator is in wrong place
+- Subsumes lambda functions after replacing
+- Add working loopnest test

 ## 8.0.1 (2024-10-24)

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -241,6 +241,8 @@ filterwarnings = [
     "error",
     "ignore::numba.core.errors.NumbaPerformanceWarning",
     "ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
+    # https://github.com/manzt/anywidget/blob/d38bb3f5f9cfc7e49e2ff1aa1ba994d66327cb02/pyproject.toml#L120
+    "ignore:Deprecated in traitlets 4.1, use the instance .metadata:DeprecationWarning",
 ]

 [tool.coverage.report]
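
For reference (not part of the commit): the added entry uses pytest's `action:message:category` warning-filter syntax, so it is roughly equivalent to the standard-library call below; the warning presumably comes from the anywidget-based visualizer, per the linked upstream pyproject.

import warnings

# Rough equivalent of the new pyproject filter: ignore the traitlets 4.1
# deprecation warning instead of letting "error" turn it into a test failure.
warnings.filterwarnings(
    "ignore",
    message="Deprecated in traitlets 4.1, use the instance .metadata",
    category=DeprecationWarning,
)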

python/egglog/builtins.py

Lines changed: 2 additions & 1 deletion
@@ -532,7 +532,8 @@ def _convert_function(a: FunctionType) -> UnstableFn:
     transformed_fn = functionalize(a, value_to_annotation)
     assert isinstance(transformed_fn, partial)
     return UnstableFn(
-        function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
+        function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
+        *transformed_fn.args,
     )
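
For reference (not part of the commit): this is the conversion that runs whenever a plain Python lambda is passed where an `UnstableFn` is expected; the generated helper function is now created with `subsume=True`, which is the changelog's "Subsumes lambda functions after replacing". A minimal sketch of that path, mirroring the new `test_filter` added later in this commit:

from egglog import *
from egglog.egraph import set_current_ruleset
from egglog.exp.array_api import *

# The lambda below goes through _convert_function: it becomes a named egglog
# function whose defining rewrite is now registered with subsume=True.
with set_current_ruleset(array_api_ruleset):
    kept = TupleInt.range(5).filter(lambda i: i < 2)

check_eq(kept.length(), Int(2), array_api_schedule)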

python/egglog/egraph.py

Lines changed: 8 additions & 0 deletions
@@ -570,6 +570,10 @@ def _generate_class_decls(  # noqa: C901,PLR0912
            fn = fn.fget
        case _:
            ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
+            if isinstance(fn, _WrappedMethod):
+                msg = f"{cls_name}.{method_name} Add the @method(...) decorator above @classmethod or @property"
+
+                raise ValueError(msg)  # noqa: TRY004
    special_function_name: SpecialFunctions | None = (
        "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
    )
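
For reference (not part of the commit): the check fires when the `@method(...)` wrapper ends up underneath `@classmethod` or `@property`, because `_generate_class_decls` then sees a bare `_WrappedMethod`. A hedged sketch of the ordering it enforces, using a made-up `Num` sort purely for illustration; the same reordering is applied to `TupleTupleInt` in `array_api.py` below.

from __future__ import annotations

from egglog import Expr, i64Like, method


class Num(Expr):  # hypothetical sort, for illustration only
    def __init__(self, value: i64Like) -> None: ...

    # OK: @method(...) is the outermost decorator, as the error message asks.
    @method(subsume=True)
    @classmethod
    def zero(cls) -> Num:
        return Num(0)

    # Wrong order -- _generate_class_decls would now raise:
    # ValueError: Num.zero Add the @method(...) decorator above @classmethod or @property
    # @classmethod
    # @method(subsume=True)
    # def zero(cls) -> Num: ...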
@@ -1373,10 +1377,14 @@ def saturate(
     """
     Saturate the egraph, running the given schedule until the egraph is saturated.
     It serializes the egraph at each step and returns a widget to visualize the egraph.
+
+    If an `expr` is passed, it's also extracted after each run and printed
     """
     from .visualizer_widget import VisualizerWidget

     def to_json() -> str:
+        if expr is not None:
+            print(self.extract(expr), "\n")
         return self._serialize(**kwargs).to_json()

     egraphs = [to_json()]
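
For reference (not part of the commit): a hedged usage sketch of the new `expr` argument, modeled on the commented-out `egraph.saturate(array_api_ruleset, expr=x, ...)` call removed from `array_api_loopnest.py` below. The visualizer widget is returned as before; with `expr` given, the extracted value is also printed after each run of the schedule.

from egglog import EGraph
from egglog.exp.array_api import *

X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))

egraph = EGraph()
x = egraph.let("x", X.shape[2])
# Visualize saturation as before, and print egraph.extract(x) after each run.
egraph.saturate(array_api_ruleset, expr=x)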

python/egglog/exp/array_api.py

Lines changed: 8 additions & 5 deletions
@@ -278,7 +278,7 @@ def single(cls, i: Int) -> TupleInt:
         return TupleInt(Int(1), lambda _: i)

     @classmethod
-    def range(cls, stop: Int) -> TupleInt:
+    def range(cls, stop: IntLike) -> TupleInt:
         return TupleInt(stop, lambda i: i)

     @classmethod

@@ -346,7 +346,6 @@ def _tuple_int(
     ti: TupleInt,
     ti2: TupleInt,
 ):
-    remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f)
     return [
         rewrite(TupleInt(i, idx_fn).length()).to(i),
         rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),

@@ -367,7 +366,11 @@ def _tuple_int(
         # filter TODO: could be written as fold w/ generic types
         rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)),
         rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to(
-            TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining),
+            TupleInt.if_(
+                filter_f(value := idx_fn(Int(k - 1))),
+                (remaining := TupleInt(k - 1, idx_fn).filter(filter_f)) + TupleInt.single(value),
+                remaining,
+            ),
             ne(k).to(i64(0)),
         ),
         # Empty

@@ -386,13 +389,13 @@ def var(cls, name: StringLike) -> TupleTupleInt: ...

     def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

-    @classmethod
     @method(subsume=True)
+    @classmethod
     def single(cls, i: TupleInt) -> TupleTupleInt:
         return TupleTupleInt(Int(1), lambda _: i)

-    @classmethod
     @method(subsume=True)
+    @classmethod
     def from_vec(cls, vec: Vec[Int]) -> TupleInt:
         return TupleInt(vec.length(), partial(index_vec_int, vec))
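
For reference (not part of the commit): the old unrolling read `idx_fn(Int(k))`, one index past the end of a length-`k` tuple, and its pre-built `remaining` term skipped index 0; the fix peels off the last element, index `k - 1`, and appends it after the filtered prefix when the predicate keeps it. A plain-Python model of the corrected semantics (illustration only, not the egglog rewrite itself):

def filter_tuple(length, idx_fn, keep):
    # Mirrors the fixed rewrite: a TupleInt is (length, idx_fn); filtering a
    # non-empty tuple filters the first length - 1 elements, then appends the
    # last element if the predicate keeps it.
    if length == 0:
        return []
    value = idx_fn(length - 1)  # idx_fn(Int(k - 1)) in the rewrite
    remaining = filter_tuple(length - 1, idx_fn, keep)
    return remaining + [value] if keep(value) else remaining


# Matches the new test_filter below: range(5) filtered by i < 2 keeps two elements.
assert len(filter_tuple(5, lambda i: i, lambda i: i < 2)) == 2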

python/egglog/exp/array_api_loopnest.py

Lines changed: 2 additions & 73 deletions
@@ -12,6 +12,8 @@
 from egglog import *
 from egglog.exp.array_api import *

+__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]
+

 class ShapeAPI(Expr):
     def __init__(self, dims: TupleIntLike) -> None: ...

@@ -105,76 +107,3 @@ def _loopnest_api_ruleset(
     yield rewrite(lna.indices, subsume=True).to(
         tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range))
     )
-
-
-@function(ruleset=array_api_ruleset, subsume=True)
-def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
-    # peel off the outer shape for result array
-    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
-    # get only the inner shape for reduction
-    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
-
-    return NDArray(
-        outshape,
-        X.dtype,
-        lambda k: sqrt(
-            LoopNestAPI.from_tuple(reduce_axis)
-            .unwrap()
-            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
-        ).to_value(),
-    )
-
-
-# %%
-# egraph = EGraph(save_egglog_string=True)
-
-# egraph.register(val.shape)
-# egraph.run(array_api_ruleset.saturate())
-# egraph.extract_multiple(val.shape, 10)
-
-# %%
-
-X = NDArray.var("X")
-assume_shape(X, (3, 2, 3, 4))
-val = linalg_norm(X, (0, 1))
-egraph = EGraph()
-x = egraph.let("x", val.shape[2])
-# egraph.display(n_inline_leaves=0)
-# egraph.extract(x)
-# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0)
-# egraph.run(array_api_ruleset.saturate())
-# egraph.extract(x)
-# egraph.display()
-
-
-# %%
-
-# x = xs[-2]
-# # %%
-# decls = x.__egg_decls__
-# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])
-
-# # %%
-# # x.__egg_typed_expr__.expr.args[1].expr.args[1]  # %%
-
-# # %%
-# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]))
-
-
-# from egglog import pretty
-
-# decl = (
-#     x.__egg_typed_expr__.expr.args[1]
-#     .expr.args[2]
-#     .expr.args[0]
-#     .expr.args[1]
-#     .expr.call.args[0]
-#     .expr.call.args[0]
-#     .expr.call.args[0]
-# )
-
-# # pprint.pprint(decl)
-
-# print(pretty.pretty_decl(decls, decl.expr))
-
-# # %%
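
For reference (not part of the commit): the `ShapeAPI` helpers used by `linalg_norm` split a shape into kept (output) axes and reduced axes. With the `(3, 2, 3, 4)` shape assumed in the new `TestLoopNest.test_shape`, reducing over axes `(0, 1)` leaves an output shape of `(3, 4)`, which is exactly what that test checks. A small sketch (expected values follow from the test; the snippet only builds the expressions, it does not run a schedule):

from egglog.exp.array_api import *
from egglog.exp.array_api_loopnest import *

X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))

shape = ShapeAPI(X.shape)
outshape = shape.deselect((0, 1)).to_tuple()   # axes 2 and 3 remain -> (3, 4)
reduce_axis = shape.select((0, 1)).to_tuple()  # axes 0 and 1 are reduced -> (3, 2)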

python/tests/test_array_api.py

Lines changed: 39 additions & 3 deletions
@@ -8,8 +8,10 @@
 from sklearn import config_context, datasets
 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

+from egglog.egraph import set_current_ruleset
 from egglog.exp.array_api import *
 from egglog.exp.array_api_jit import jit
+from egglog.exp.array_api_loopnest import *
 from egglog.exp.array_api_numba import array_api_numba_schedule
 from egglog.exp.array_api_program_gen import *

@@ -68,6 +70,41 @@ def test_reshape_vec_noop():
     egraph.check(eq(res).to(x))


+def test_filter():
+    with set_current_ruleset(array_api_ruleset):
+        x = TupleInt.range(5).filter(lambda i: i < 2).length()
+    check_eq(x, Int(2), array_api_schedule)
+
+
+@function(ruleset=array_api_ruleset, subsume=True)
+def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
+    # peel off the outer shape for result array
+    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
+    # get only the inner shape for reduction
+    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()
+
+    return NDArray(
+        outshape,
+        X.dtype,
+        lambda k: sqrt(
+            LoopNestAPI.from_tuple(reduce_axis)
+            .unwrap()
+            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
+        ).to_value(),
+    )
+
+
+class TestLoopNest:
+    def test_shape(self):
+        X = NDArray.var("X")
+        assume_shape(X, (3, 2, 3, 4))
+        val = linalg_norm(X, (0, 1))
+
+        check_eq(val.shape.length(), Int(2), array_api_schedule)
+        check_eq(val.shape[0], Int(3), array_api_schedule)
+        check_eq(val.shape[1], Int(4), array_api_schedule)
+
+
 # This test happens in different steps. Each will be benchmarked and saved as a snapshot.
 # The next step will load the old snapshot and run their test on it.

@@ -80,7 +117,6 @@ def run_lda(x, y):

 iris = datasets.load_iris()
 X_np, y_np = (iris.data, iris.target)
-res_np = run_lda(X_np, y_np)


 def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:

@@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
         optimized_expr = simplify_lda(egraph, expr)
         fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
         py_object = benchmark(load_source, fn_program, egraph)
-        assert np.allclose(py_object(X_np, y_np), res_np)
+        assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
         assert egraph.eval(fn_program.statements) == snapshot_py

     @pytest.mark.parametrize(

@@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
     )
     def test_execution(self, fn, benchmark):
         # warmup once for numba
-        assert np.allclose(res_np, fn(X_np, y_np))
+        assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))
         benchmark(fn, X_np, y_np)
