Skip to content

Commit 3104f49

Browse files
Roger-luodavid-plcduck
authored
Roger/backport ilist sorted (#497)
cc: @cduck --------- Co-authored-by: David Plankensteiner <[email protected]> Co-authored-by: Casey Duckering <[email protected]>
1 parent 6dbcb97 commit 3104f49

File tree

6 files changed

+142
-2
lines changed

6 files changed

+142
-2
lines changed

src/kirin/dialects/ilist/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,6 @@
3131
foldl as foldl,
3232
foldr as foldr,
3333
range as range,
34+
sorted as sorted,
3435
for_each as for_each,
3536
)

src/kirin/dialects/ilist/_wrapper.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,36 @@ def for_each(
6565
fn: typing.Callable[[ElemT], typing.Any],
6666
collection: IList[ElemT, LenT] | list[ElemT],
6767
) -> None: ...
68+
69+
70+
@typing.overload
71+
def sorted(collection: IList[ElemT, LenT] | list[ElemT]) -> IList[ElemT, LenT]: ...
72+
73+
74+
@typing.overload
75+
def sorted(
76+
collection: IList[ElemT, LenT] | list[ElemT], reverse: bool
77+
) -> IList[ElemT, LenT]: ...
78+
79+
80+
@typing.overload
81+
def sorted(
82+
collection: IList[ElemT, LenT] | list[ElemT],
83+
key: typing.Callable[[ElemT], OutElemT],
84+
) -> IList[ElemT, LenT]: ...
85+
86+
87+
@typing.overload
88+
def sorted(
89+
collection: IList[ElemT, LenT] | list[ElemT],
90+
key: typing.Callable[[ElemT], OutElemT],
91+
reverse: bool,
92+
) -> IList[ElemT, LenT]: ...
93+
94+
95+
@lowering.wraps(stmts.Sorted)
96+
def sorted(
97+
collection: IList[ElemT, LenT] | list[ElemT],
98+
key: typing.Callable[[ElemT], OutElemT],
99+
reverse: bool,
100+
) -> IList[ElemT, LenT]: ...

src/kirin/dialects/ilist/interp.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin.dialects.py.len import Len
44
from kirin.dialects.py.binop import Add
55

6-
from .stmts import Map, New, Push, Scan, Foldl, Foldr, Range, ForEach
6+
from .stmts import Map, New, Push, Scan, Foldl, Foldr, Range, Sorted, ForEach
77
from .runtime import IList
88
from ._dialect import dialect
99

@@ -91,4 +91,12 @@ def for_each(self, interp: Interpreter, frame: Frame, stmt: ForEach):
9191
for elem in coll.data:
9292
# NOTE: assume fn has been type checked
9393
interp.run_method(fn, (elem,))
94-
return (None,)
94+
return
95+
96+
@impl(Sorted)
97+
def sorted(self, inter: Interpreter, frame: Frame, stmt: Sorted):
98+
key = frame.get(stmt.key)
99+
reverse: bool = frame.get(stmt.reverse)
100+
coll: IList = frame.get(stmt.collection)
101+
102+
return (IList(data=sorted(coll.data, key=key, reverse=reverse)),)

src/kirin/dialects/ilist/lowering.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import ast
2+
from dataclasses import dataclass
23

34
from kirin import types, lowering
5+
from kirin.dialects import py
46

57
from . import stmts as ilist
68
from ._dialect import dialect
@@ -48,3 +50,43 @@ def lower_Call_IList(
4850
return state.current_frame.push(stmt)
4951
else:
5052
return self.lower_List(state, value)
53+
54+
55+
@dataclass(frozen=True)
56+
class SortedLowering(lowering.FromPythonCall["ilist.Sorted"]):
57+
"""
58+
Custom lowering for Sorted to deal with optional arguments `key=None` and `reverse=False`
59+
"""
60+
61+
def lower(
62+
self, stmt: type["ilist.Sorted"], state: lowering.State[ast.AST], node: ast.Call
63+
) -> lowering.Result:
64+
args = node.args
65+
66+
if len(args) != 1:
67+
raise lowering.BuildError(
68+
f"Expected single argument in sorted, got {len(args)}"
69+
)
70+
collection = state.lower(args[0]).expect_one()
71+
72+
key = None
73+
reverse = None
74+
for kwarg in node.keywords:
75+
if kwarg.arg == "key":
76+
key = state.lower(kwarg.value).expect_one()
77+
elif kwarg.arg == "reverse":
78+
reverse = state.lower(kwarg.value).expect_one()
79+
else:
80+
raise lowering.BuildError(
81+
f"Got unexpected keyword argument in sorted {kwarg.arg}"
82+
)
83+
84+
if key is None:
85+
key = state.current_frame.push(py.Constant(None)).result
86+
87+
if reverse is None:
88+
reverse = state.current_frame.push(py.Constant(False)).result
89+
90+
return state.current_frame.push(
91+
stmt(collection=collection, key=key, reverse=reverse)
92+
)

src/kirin/dialects/ilist/stmts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .runtime import IList
77
from ._dialect import dialect
8+
from .lowering import SortedLowering
89

910
ElemT = types.TypeVar("ElemT")
1011
ListLen = types.TypeVar("ListLen")
@@ -112,3 +113,15 @@ class ForEach(ir.Statement):
112113
purity: bool = info.attribute(default=False)
113114
fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType))
114115
collection: ir.SSAValue = info.argument(IListType[ElemT])
116+
117+
118+
@statement(dialect=dialect)
119+
class Sorted(ir.Statement):
120+
traits = frozenset({ir.MaybePure(), SortedLowering()})
121+
purity: bool = info.attribute(default=False)
122+
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
123+
key: ir.SSAValue = info.argument(
124+
types.Union((types.Generic(ir.Method, [ElemT], ElemT), types.NoneType))
125+
)
126+
reverse: ir.SSAValue = info.argument(types.Bool)
127+
result: ir.ResultValue = info.result(IListType[ElemT, ListLen])

test/dialects/test_ilist_wrapper.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,46 @@ def scan_wrap():
7575
10 + 1 + 1 + 1 + 3,
7676
10 + 1 + 1 + 1 + 1 + 4,
7777
]
78+
79+
80+
def test_sorted():
81+
def key_test(a: int) -> int:
82+
return a
83+
84+
@basic
85+
def main():
86+
ls = [2, 3, 1, 5, 4]
87+
return (
88+
ilist.sorted(ls),
89+
ilist.sorted(ls, key=key_test),
90+
ilist.sorted(ls, reverse=True),
91+
)
92+
93+
main.print()
94+
95+
ls1, ls2, ls3 = main()
96+
assert ls1.data == [1, 2, 3, 4, 5]
97+
assert ls2.data == ls1.data
98+
assert ls3.data == [5, 4, 3, 2, 1]
99+
100+
def first(x: tuple[str, int]) -> str:
101+
return x[0]
102+
103+
def second(x: tuple[str, int]) -> int:
104+
return x[1]
105+
106+
@basic
107+
def main2():
108+
ls = [("a", 4), ("b", 3), ("c", 1)]
109+
return (
110+
ilist.sorted(ls, key=first),
111+
ilist.sorted(ls, key=second),
112+
ilist.sorted(ls, key=second, reverse=True),
113+
)
114+
115+
main2.print()
116+
117+
ls1, ls2, ls3 = main2()
118+
assert ls1.data == [("a", 4), ("b", 3), ("c", 1)]
119+
assert ls3.data == ls1.data
120+
assert ls2.data == [("c", 1), ("b", 3), ("a", 4)]

0 commit comments

Comments
 (0)