Skip to content

Commit 830adf2

Browse files
authored
add wrappers for ilist (#215)
1 parent ce08ce1 commit 830adf2

File tree

3 files changed

+117
-0
lines changed

3 files changed

+117
-0
lines changed

src/kirin/dialects/ilist/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
from .passes import IListDesugar as IListDesugar
2525
from .runtime import IList as IList
2626
from ._dialect import dialect as dialect
27+
from ._wrapper import map as map, scan as scan, foldl as foldl, foldr as foldr
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Any, TypeVar, Iterable
2+
3+
from kirin import ir
4+
from kirin.lowering import wraps
5+
6+
from . import stmts
7+
from .runtime import IList
8+
9+
T_Elem = TypeVar("T_Elem")
10+
T_Out = TypeVar("T_Out")
11+
T_Result = TypeVar("T_Result")
12+
13+
14+
@wraps(stmts.Map)
15+
def map(fn: ir.Method[[T_Elem], T_Out], collection: Iterable) -> IList | list: ...
16+
17+
18+
@wraps(stmts.Foldr)
19+
def foldr(
20+
fn: ir.Method[[T_Elem, T_Out], T_Out], collection: Iterable, init: T_Out
21+
) -> T_Out: ...
22+
23+
24+
@wraps(stmts.Foldl)
25+
def foldl(
26+
fn: ir.Method[[T_Out, T_Elem], T_Out], collection: Iterable, init: T_Out
27+
) -> T_Out: ...
28+
29+
30+
@wraps(stmts.Scan)
31+
def scan(
32+
fn: ir.Method[[T_Out, T_Elem], tuple[T_Out, T_Result]],
33+
collection: Iterable,
34+
init: T_Out,
35+
) -> tuple[T_Out, IList[T_Result, Any]]: ...
36+
37+
38+
@wraps(stmts.ForEach)
39+
def for_each(fn: ir.Method[[T_Elem], None], collection: Iterable) -> None: ...
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from kirin.prelude import basic
2+
from kirin.dialects import ilist
3+
4+
5+
def test_map_wrapper():
6+
7+
@basic
8+
def add1(x: int):
9+
return x + 1
10+
11+
@basic
12+
def map_wrap():
13+
return ilist.map(add1, range(5))
14+
15+
out = map_wrap()
16+
assert isinstance(out, ilist.IList)
17+
assert out.data == [1, 2, 3, 4, 5]
18+
19+
20+
def test_foldr_wrapper():
21+
22+
@basic
23+
def add_fold(x: int, out: int):
24+
return out + x
25+
26+
@basic
27+
def map_foldr():
28+
return ilist.foldr(add_fold, range(5), init=10)
29+
30+
out = map_foldr()
31+
assert isinstance(out, int)
32+
assert out == 10 + 0 + 1 + 2 + 3 + 4
33+
34+
35+
def test_foldl_wrapper():
36+
37+
@basic
38+
def add_fold2(out: int, x: int):
39+
return out + x
40+
41+
@basic
42+
def map_foldl():
43+
return ilist.foldr(add_fold2, range(5), init=10)
44+
45+
out = map_foldl()
46+
assert isinstance(out, int)
47+
assert out == 10 + 0 + 1 + 2 + 3 + 4
48+
49+
50+
def test_scan_wrapper():
51+
52+
@basic
53+
def add_scan(out: int, x: int):
54+
return out + 1, out + x
55+
56+
@basic
57+
def scan_wrap():
58+
return ilist.scan(add_scan, range(5), init=10)
59+
60+
out = scan_wrap()
61+
assert isinstance(out, tuple)
62+
assert len(out) == 2
63+
64+
res = out[0]
65+
out_list = out[1]
66+
67+
assert isinstance(res, int)
68+
assert res == 10 + 1 * 5
69+
70+
assert isinstance(out_list, ilist.IList)
71+
assert out_list.data == [
72+
10 + 0,
73+
10 + 1 + 1,
74+
10 + 1 + 1 + 2,
75+
10 + 1 + 1 + 1 + 3,
76+
10 + 1 + 1 + 1 + 1 + 4,
77+
]

0 commit comments

Comments
 (0)