Skip to content

Commit c25c866

Browse files
Add e2e loopnest value test
1 parent 5f44c26 commit c25c866

File tree

10 files changed

+147
-18
lines changed

10 files changed

+147
-18
lines changed

python/egglog/exp/array_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ class Int(Expr, ruleset=array_api_ruleset):
119119
# https://en.wikipedia.org/wiki/Bottom_type
120120
NEVER: ClassVar[Int]
121121

122+
@classmethod
123+
def var(cls, name: StringLike) -> Int: ...
124+
122125
def __init__(self, value: i64Like) -> None: ...
123126

124127
def __invert__(self) -> Int: ...

python/egglog/exp/array_api_loopnest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""
2-
In progress module
2+
Example module to replicate behavior expressed in
33
44
https://gist.github.com/sklam/5e5737137d48d6e5b816d14a90076f1d
5-
65
"""
76

87
# %%

python/egglog/exp/array_api_program_gen.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def int_program(x: Int) -> Program: ...
3838

3939

4040
@array_api_program_gen_ruleset.register
41-
def _int_program(i64_: i64, i: Int, j: Int):
41+
def _int_program(i64_: i64, i: Int, j: Int, s: String):
42+
yield rewrite(int_program(Int.var(s))).to(Program(s, True))
4243
yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string()))
4344
yield rewrite(int_program(~i)).to(Program("~") + int_program(i))
4445
yield rewrite(bool_program(i < j)).to(Program("(") + int_program(i) + " < " + int_program(j) + ")")
@@ -145,9 +146,14 @@ def _value_program(i: Int, b: Boolean, f: Float, x: NDArray, v1: Value, v2: Valu
145146
yield rewrite(value_program(x.to_value())).to(ndarray_program(x))
146147
yield rewrite(value_program(v1 < v2)).to(Program("(") + value_program(v1) + " < " + value_program(v2) + ")")
147148
yield rewrite(value_program(v1 / v2)).to(Program("(") + value_program(v1) + " / " + value_program(v2) + ")")
149+
yield rewrite(value_program(v1 + v2)).to(Program("(") + value_program(v1) + " + " + value_program(v2) + ")")
150+
yield rewrite(value_program(v1 * v2)).to(Program("(") + value_program(v1) + " * " + value_program(v2) + ")")
148151
yield rewrite(bool_program(v1.to_bool)).to(value_program(v1))
149152
yield rewrite(int_program(v1.to_int)).to(value_program(v1))
150-
yield rewrite(value_program(xs.index(ti))).to(ndarray_program(xs) + "[" + tuple_int_program(ti) + "]")
153+
yield rewrite(value_program(xs.index(ti))).to((ndarray_program(xs) + "[" + tuple_int_program(ti) + "]").assign())
154+
yield rewrite(value_program(v1.sqrt())).to(Program("np.sqrt(") + value_program(v1) + ")")
155+
yield rewrite(value_program(v1.real())).to(Program("np.real(") + value_program(v1) + ")")
156+
yield rewrite(value_program(v1.conj())).to(Program("np.conj(") + value_program(v1) + ")")
151157

152158

153159
@function

python/egglog/exp/program_gen.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ def function_two(self, arg1: ProgramLike, arg2: ProgramLike, name: StringLike =
4545
Returns a new program defining a function with two arguments.
4646
"""
4747

48+
def function_three(
49+
self, arg1: ProgramLike, arg2: ProgramLike, arg3: ProgramLike, name: StringLike = String("__fn")
50+
) -> Program:
51+
"""
52+
Returns a new program defining a function with three arguments.
53+
"""
54+
4855
def expr_to_statement(self) -> Program:
4956
"""
5057
Returns a new program with the expression as a statement and the new expression empty.
@@ -142,10 +149,12 @@ def program_gen_ruleset(
142149
s3: String,
143150
s4: String,
144151
s5: String,
152+
s6: String,
145153
p: Program,
146154
p1: Program,
147155
p2: Program,
148156
p3: Program,
157+
p4: Program,
149158
i: i64,
150159
i2: i64,
151160
b: Bool,
@@ -357,6 +366,7 @@ def program_gen_ruleset(
357366

358367
##
359368
# Function two
369+
##
360370

361371
# When compiling a function, the two args, p2 and p3, should get compiled when we compile p1, and should just be vars.
362372
fn_two = eq(p).to(p1.function_two(p2, p3, s1))
@@ -387,3 +397,35 @@ def program_gen_ruleset(
387397
set_(p.next_sym).to(i),
388398
set_(p.expr).to(s1),
389399
)
400+
401+
##
402+
# Function three
403+
##
404+
405+
fn_three = eq(p).to(p1.function_three(p2, p3, p4, s1))
406+
yield rule(fn_three, p.compile(i)).then(
407+
set_(p2.parent).to(p),
408+
set_(p3.parent).to(p),
409+
set_(p1.parent).to(p),
410+
set_(p4.parent).to(p),
411+
p2.compile(i),
412+
p3.compile(i),
413+
p1.compile(i),
414+
p4.compile(i),
415+
set_(p.is_identifer).to(Bool(True)),
416+
)
417+
yield rule(
418+
fn_three,
419+
p.compile(i),
420+
eq(s2).to(p1.expr),
421+
eq(s3).to(p1.statements),
422+
eq(s4).to(p2.expr),
423+
eq(s5).to(p3.expr),
424+
eq(s6).to(p4.expr),
425+
).then(
426+
set_(p.statements).to(
427+
join("def ", s1, "(", s4, ", ", s5, ", ", s6, "):\n ", s3.replace("\n", "\n "), "return ", s2, "\n")
428+
),
429+
set_(p.next_sym).to(i),
430+
set_(p.expr).to(s1),
431+
)

python/egglog/runtime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,10 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
601601
case RuntimeClass(thunk, tp):
602602
return InitRef(tp.name), thunk()
603603
case RuntimeExpr(decl_thunk, expr_thunk):
604-
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
605-
raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
604+
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
605+
expr.callable, ConstantRef | ClassVariableRef
606+
):
607+
raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}")
606608
return expr.callable, decl_thunk()
607609
case _:
608610
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
def __fn(X, i, j):
2+
_0 = X[(0, 0, i, j, )]
3+
_1 = X[(0, 1, i, j, )]
4+
_2 = X[(1, 0, i, j, )]
5+
_3 = X[(1, 1, i, j, )]
6+
_4 = X[(2, 0, i, j, )]
7+
_5 = X[(2, 1, i, j, )]
8+
return np.sqrt((((((np.real((np.conj(_0) * _0)) + np.real((np.conj(_1) * _1))) + np.real((np.conj(_2) * _2))) + np.real((np.conj(_3) * _3))) + np.real((np.conj(_4) * _4))) + np.real((np.conj(_5) * _5))))
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
_Value_1 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(0), Int(0), Int.var("i"), Int.var("j"))))
2+
_Value_2 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(0), Int(1), Int.var("i"), Int.var("j"))))
3+
_Value_3 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(1), Int(0), Int.var("i"), Int.var("j"))))
4+
_Value_4 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(1), Int(1), Int.var("i"), Int.var("j"))))
5+
_Value_5 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(2), Int(0), Int.var("i"), Int.var("j"))))
6+
_Value_6 = NDArray.var("X").index(TupleInt.from_vec(Vec[Int](Int(2), Int(1), Int.var("i"), Int.var("j"))))
7+
(
8+
(
9+
((((_Value_1.conj() * _Value_1).real() + (_Value_2.conj() * _Value_2).real()) + (_Value_3.conj() * _Value_3).real()) + (_Value_4.conj() * _Value_4).real())
10+
+ (_Value_5.conj() * _Value_5).real()
11+
)
12+
+ (_Value_6.conj() * _Value_6).real()
13+
).sqrt()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
def my_fn(x, y, z):
2+
_0 = -x
3+
assert _0 > 0
4+
_1 = _0 + y
5+
_2 = _1 + z
6+
_3 = _2 + 2
7+
_4 = _3 + _2
8+
return _4

python/tests/test_array_api.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -225,29 +225,33 @@ def linalg_norm_v2(X: NDArrayLike, axis: TupleIntLike) -> NDArray:
225225
)
226226

227227

228-
def linalg_val(linalg_fn: Callable[[NDArray, TupleIntLike], NDArray]) -> NDArray:
229-
X = NDArray.var("X")
228+
def linalg_val(X: NDArray, linalg_fn: Callable[[NDArray, TupleIntLike], NDArray]) -> NDArray:
230229
assume_shape(X, (3, 2, 3, 4))
231230
return linalg_fn(X, (0, 1))
232231

233232

234233
class TestLoopNest:
235234
@pytest.mark.parametrize("linalg_fn", [linalg_norm, linalg_norm_v2])
236235
def test_shape(self, linalg_fn):
237-
check_eq(linalg_val(linalg_fn).shape, TupleInt.from_vec((3, 4)), array_api_schedule)
236+
X = np.random.random((3, 2, 3, 4))
237+
expect = np.linalg.norm(X, axis=(0, 1))
238+
assert expect.shape == (3, 4)
239+
240+
check_eq(linalg_val(constant("X", NDArray), linalg_fn).shape, TupleInt.from_vec((3, 4)), array_api_schedule)
238241

239242
@pytest.mark.parametrize("linalg_fn", [linalg_norm, linalg_norm_v2])
240-
def test_index(self, linalg_fn):
243+
def test_abstract_index(self, linalg_fn):
241244
i = constant("i", Int)
242245
j = constant("j", Int)
243-
idxed = linalg_val(linalg_fn).index((i, j))
244-
_NDArray_1 = NDArray.var("X")
245-
_Value_1 = _NDArray_1.index(TupleInt.from_vec(Vec[Int](Int(0), Int(0), i, j)))
246-
_Value_2 = _NDArray_1.index(TupleInt.from_vec(Vec[Int](Int(0), Int(1), i, j)))
247-
_Value_3 = _NDArray_1.index(TupleInt.from_vec(Vec[Int](Int(1), Int(0), i, j)))
248-
_Value_4 = _NDArray_1.index(TupleInt.from_vec(Vec[Int](Int(1), Int(1), i, j)))
249-
_Value_5 = _NDArray_1.index(TupleInt.from_vec(Vec[Int](Int(2), Int(0), i, j)))
250-
_Value_6 = _NDArray_1.index(TupleInt.from_vec(Vec[Int](Int(2), Int(1), i, j)))
246+
X = constant("X", NDArray)
247+
idxed = linalg_val(X, linalg_fn).index((i, j))
248+
249+
_Value_1 = X.index(TupleInt.from_vec(Vec[Int](Int(0), Int(0), i, j)))
250+
_Value_2 = X.index(TupleInt.from_vec(Vec[Int](Int(0), Int(1), i, j)))
251+
_Value_3 = X.index(TupleInt.from_vec(Vec[Int](Int(1), Int(0), i, j)))
252+
_Value_4 = X.index(TupleInt.from_vec(Vec[Int](Int(1), Int(1), i, j)))
253+
_Value_5 = X.index(TupleInt.from_vec(Vec[Int](Int(2), Int(0), i, j)))
254+
_Value_6 = X.index(TupleInt.from_vec(Vec[Int](Int(2), Int(1), i, j)))
251255
res = (
252256
(
253257
(
@@ -263,6 +267,36 @@ def test_index(self, linalg_fn):
263267
).sqrt()
264268
check_eq(idxed, res, array_api_schedule)
265269

270+
def test_index_codegen(self, snapshot_py):
271+
X = NDArray.var("X")
272+
i = Int.var("i")
273+
j = Int.var("j")
274+
idxed = linalg_val(X, linalg_norm_v2).index((i, j))
275+
simplified_index = simplify(idxed, array_api_schedule)
276+
assert str(simplified_index) == snapshot_py(name="expr")
277+
278+
res = EvalProgram(
279+
value_program(simplified_index).function_three(ndarray_program(X), int_program(i), int_program(j)),
280+
{"np": np},
281+
)
282+
egraph = EGraph()
283+
egraph.register(res)
284+
egraph.run(array_api_program_gen_schedule)
285+
print(
286+
egraph.extract(
287+
value_program(simplified_index).function_three(ndarray_program(X), int_program(i), int_program(j))
288+
)
289+
)
290+
# egraph.display(split_primitive_outputs=True, n_inline_leaves=3, split_functions=[TupleInt.EMPTY, TupleInt.append, Int])
291+
assert egraph.eval(res.statements) == snapshot_py(name="code")
292+
293+
fn_value = egraph.eval(res.py_object)
294+
X = np.random.random((3, 2, 3, 4))
295+
expect = np.linalg.norm(X, axis=(0, 1))
296+
297+
for idxs in np.ndindex(*expect.shape):
298+
assert np.allclose(fn_value(X, *idxs), expect[idxs], rtol=1e-03)
299+
266300

267301
# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
268302
# The next step will load the old snapshot and run their test on it.

python/tests/test_program_gen.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,20 @@ def test_to_string(snapshot_py) -> None:
6161
assert egraph.eval(fn.statements) == snapshot_py
6262

6363

64+
def test_to_string_function_three(snapshot_py) -> None:
65+
first = assume_pos(-Math.var("x")) + Math.var("y") + Math.var("z")
66+
fn = (first + Math(2) + first).program.function_three(
67+
Math.var("x").program, Math.var("y").program, Math.var("z").program, "my_fn"
68+
)
69+
egraph = EGraph()
70+
egraph.register(fn)
71+
egraph.register(fn.compile())
72+
egraph.run((to_program_ruleset | program_gen_ruleset).saturate())
73+
# egraph.display(n_inline_leaves=1)
74+
assert egraph.eval(fn.expr) == "my_fn"
75+
assert egraph.eval(fn.statements) == snapshot_py
76+
77+
6478
def test_py_object():
6579
x = Math.var("x")
6680
y = Math.var("y")

0 commit comments

Comments
 (0)