Skip to content

Commit f3b96c6

Browse files
Start adding support for getting actual value of indexed array.
1 parent 5c73999 commit f3b96c6

File tree

3 files changed

+87
-17
lines changed

3 files changed

+87
-17
lines changed

python/egglog/exp/array_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ def __len__(self) -> int:
303303
def __iter__(self) -> Iterator[Int]:
304304
return iter(self[i] for i in range(len(self)))
305305

306+
# TODO: Rename to reduce to match Python? And re-order?
306307
def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
307308

308309
def fold_boolean(self, init: Boolean, f: Callable[[Boolean, Int], Boolean]) -> Boolean: ...
@@ -861,7 +862,8 @@ def __rxor__(self, other: NDArray) -> NDArray: ...
861862
def __ror__(self, other: NDArray) -> NDArray: ...
862863

863864
@classmethod
864-
def scalar(cls, value: Value) -> NDArray: ...
865+
def scalar(cls, value: Value) -> NDArray:
866+
return NDArray(TupleInt.EMPTY, value.dtype, lambda _: value)
865867

866868
def to_value(self) -> Value: ...
867869

@@ -900,10 +902,12 @@ def _ndarray(
900902
shape: TupleInt,
901903
dtype: DType,
902904
idx_fn: Callable[[TupleInt], Value],
905+
idx: TupleInt,
903906
):
904907
return [
905908
rewrite(NDArray(shape, dtype, idx_fn).shape).to(shape),
906909
rewrite(NDArray(shape, dtype, idx_fn).dtype).to(dtype),
910+
rewrite(NDArray(shape, dtype, idx_fn).index(idx), subsume=True).to(idx_fn(idx)),
907911
rewrite(x.ndim).to(x.shape.length()),
908912
# rewrite(NDArray.scalar(Value.bool(b)).to_bool()).to(b),
909913
# Converting to a value requires a scalar bool value

python/egglog/exp/array_api_loopnest.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,63 @@ def tuple_tuple_int_reduce_ndarray(
6969
) -> NDArray: ...
7070

7171

72+
@array_api_ruleset.register
73+
def _tuple_tuple_int_reduce_ndarray(
74+
idx_fn: Callable[[Int], TupleInt], f: Callable[[NDArray, TupleInt], NDArray], i: NDArray, k: i64
75+
):
76+
yield rewrite(tuple_tuple_int_reduce_ndarray(TupleTupleInt(0, idx_fn), f, i)).to(i)
77+
yield rewrite(tuple_tuple_int_reduce_ndarray(TupleTupleInt(Int(k), idx_fn), f, i), subsume=True).to(
78+
f(
79+
tuple_tuple_int_reduce_ndarray(TupleTupleInt(k - 1, lambda i: idx_fn(i + 1)), f, i),
80+
idx_fn(Int(0)),
81+
),
82+
ne(k).to(i64(0)),
83+
)
84+
85+
7286
@function
7387
def tuple_int_map_tuple_int(xs: TupleInt, fn: Callable[[Int], TupleInt]) -> TupleTupleInt: ...
7488

7589

90+
@array_api_ruleset.register
91+
def _tuple_int_map_tuple_int(length: Int, fn: Callable[[Int], TupleInt], idx_fn: Callable[[Int], Int]):
92+
yield rewrite(
93+
tuple_int_map_tuple_int(
94+
TupleInt(length, idx_fn),
95+
fn,
96+
)
97+
).to(TupleTupleInt(length, lambda i: fn(idx_fn(i))))
98+
99+
76100
@function
77-
def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt: ...
101+
def tuple_tuple_int_map_int(xs: TupleTupleInt, fn: Callable[[TupleInt], Int]) -> TupleInt: ...
102+
103+
104+
@array_api_ruleset.register
105+
def _tuple_tuple_int_map_int(length: Int, fn: Callable[[TupleInt], Int], idx_fn: Callable[[Int], TupleInt]):
106+
yield rewrite(
107+
tuple_tuple_int_map_int(
108+
TupleTupleInt(length, idx_fn),
109+
fn,
110+
)
111+
).to(TupleInt(length, lambda i: fn(idx_fn(i))))
112+
113+
114+
@function
115+
def tuple_tuple_int_product_index(xs: TupleTupleInt, i: Int) -> TupleInt: ...
116+
117+
118+
@function(subsume=True, ruleset=array_api_ruleset)
119+
def tuple_tuple_int_product(xs: TupleTupleInt) -> TupleTupleInt:
120+
"""
121+
Cartesian product of inputs
122+
123+
https://docs.python.org/3/library/itertools.html#itertools.product
124+
"""
125+
# length is product of lengths
126+
length = tuple_tuple_int_map_int(xs, lambda x: x.length()).fold(Int(1), lambda x, y: x * y)
127+
128+
return TupleTupleInt(length, partial(tuple_tuple_int_product_index, xs))
78129

79130

80131
@array_api_ruleset.register
@@ -107,3 +158,5 @@ def _loopnest_api_ruleset(
107158
yield rewrite(lna.indices, subsume=True).to(
108159
tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range))
109160
)
161+
# unwrap
162+
yield rewrite(OptionalLoopNestAPI(lna).unwrap()).to(lna)

python/tests/test_array_api.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
9494
)
9595

9696

97+
X = NDArray.var("X")
98+
assume_shape(X, (3, 2, 3, 4))
99+
val = linalg_norm(X, (0, 1))
100+
i = constant("i", Int)
101+
j = constant("j", Int)
102+
idxed = val.index((i, j))
103+
104+
egraph = EGraph()
105+
egraph.register(idxed)
106+
egraph.run(array_api_schedule)
107+
print(egraph.extract(idxed))
108+
109+
97110
class TestLoopNest:
98111
def test_shape(self):
99112
X = NDArray.var("X")
@@ -222,18 +235,18 @@ def test_execution(self, fn, benchmark):
222235

223236
# if calling as script, print out egglog source for test
224237
# similar to jit, but don't include pyobject parts so it works in vanilla egglog
225-
if __name__ == "__main__":
226-
print("Generating egglog source for test")
227-
egraph = EGraph(save_egglog_string=True)
228-
X_ = NDArray.var("X")
229-
y_ = NDArray.var("y")
230-
with egraph:
231-
expr = lda(X_, y_)
232-
optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
233-
fn_program = ndarray_function_two_program(optimized_expr, X_, y_)
234-
egraph.register(fn_program.compile())
235-
egraph.run(array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate())
236-
egraph.extract(fn_program.statements)
237-
name = "python.egg"
238-
print("Saving to", name)
239-
Path(name).write_text(egraph.as_egglog_string)
238+
# if __name__ == "__main__":
239+
# print("Generating egglog source for test")
240+
# egraph = EGraph(save_egglog_string=True)
241+
# X_ = NDArray.var("X")
242+
# y_ = NDArray.var("y")
243+
# with egraph:
244+
# expr = lda(X_, y_)
245+
# optimized_expr = egraph.simplify(expr, array_api_numba_schedule)
246+
# fn_program = ndarray_function_two_program(optimized_expr, X_, y_)
247+
# egraph.register(fn_program.compile())
248+
# egraph.run(array_api_program_gen_ruleset.saturate() + program_gen_ruleset.saturate())
249+
# egraph.extract(fn_program.statements)
250+
# name = "python.egg"
251+
# print("Saving to", name)
252+
# Path(name).write_text(egraph.as_egglog_string)

0 commit comments

Comments
 (0)