Skip to content

Commit 3c8793b

Browse files
authored
revise prelude dialect groups (#148)
this PR revise the prelude dialect groups as following: `basic` and `basic_no_opt` are the basic language with immutable list and anything that is common in eDSLs in the future. `python_basic` and `python_no_opt` are the exact python language with mutable list and anything that is Python oriented in the future. the main issue here is to have a separation where those hard to compile but easy to interpret Python features should go.
1 parent adfd236 commit 3c8793b

File tree

11 files changed

+100
-26
lines changed

11 files changed

+100
-26
lines changed

src/kirin/dialects/ilist/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@
2121
ForEach as ForEach,
2222
IListType as IListType,
2323
)
24+
from .passes import IListDesugar as IListDesugar
2425
from .runtime import IList as IList
2526
from ._dialect import dialect as dialect

src/kirin/dialects/ilist/passes.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from kirin.ir import Method, types
2+
from kirin.rewrite import Walk, Fixpoint
3+
from kirin.passes.abc import Pass
4+
from kirin.rewrite.result import RewriteResult
5+
from kirin.dialects.ilist.rewrite import List2IList
6+
7+
8+
class IListDesugar(Pass):
9+
"""This pass desugars the Python list dialect
10+
to the immutable list dialect by rewriting all
11+
constant `list` type into `IList` type.
12+
"""
13+
14+
def unsafe_run(self, mt: Method) -> RewriteResult:
15+
for arg in mt.args:
16+
_check_list(arg.type, arg.type)
17+
return Fixpoint(Walk(List2IList())).rewrite(mt.code)
18+
19+
20+
def _check_list(total: types.TypeAttribute, type_: types.TypeAttribute):
21+
if isinstance(type_, types.Generic):
22+
_check_list(total, type_.body)
23+
for var in type_.vars:
24+
_check_list(total, var)
25+
if type_.vararg:
26+
_check_list(total, type_.vararg.typ)
27+
elif isinstance(type_, types.PyClass):
28+
if issubclass(type_.typ, list):
29+
raise TypeError(
30+
f"Invalid type {total} for this kernel, use IList instead of {type_}."
31+
)
32+
return

src/kirin/dialects/ilist/rewrite/list.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1919
return RewriteResult(has_done_something=has_done_something)
2020

2121
def _rewrite_SSAValue_type(self, value: ir.SSAValue):
22-
if value.type.is_subseteq(ir.types.List):
23-
if isinstance(value.type, ir.types.Generic):
24-
value.type = IListType[value.type.vars[0], ir.types.Any]
25-
else:
26-
value.type = IListType[ir.types.Any, ir.types.Any]
22+
# NOTE: cannot use issubseteq here because type can be Bottom
23+
if isinstance(value.type, ir.types.Generic) and issubclass(
24+
value.type.body.typ, list
25+
):
26+
value.type = IListType[value.type.vars[0], ir.types.Any]
27+
return True
28+
elif isinstance(value.type, ir.types.PyClass) and issubclass(
29+
value.type.typ, list
30+
):
31+
value.type = IListType[ir.types.Any, ir.types.Any]
2732
return True
2833
return False

src/kirin/dialects/ilist/runtime.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,17 @@ def __str__(self) -> str:
4040
return f"IList({self.data})"
4141

4242
def __iter__(self):
43-
raise NotImplementedError("Cannot use IList outside kernel.")
43+
return iter(self.data)
4444

45-
def __getitem__(self, index: int) -> T:
45+
@overload
46+
def __getitem__(self, index: slice) -> "IList[T, Any]": ...
47+
48+
@overload
49+
def __getitem__(self, index: int) -> T: ...
50+
51+
def __getitem__(self, index: int | slice) -> T | "IList[T, Any]":
52+
if isinstance(index, slice):
53+
return IList(self.data[index])
4654
return self.data[index]
4755

4856
def __eq__(self, value: object) -> bool:

src/kirin/ir/attrs/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
169169
def __hash__(self) -> int:
170170
return hash((PyClass, self.typ))
171171

172+
def __repr__(self) -> str:
173+
return self.typ.__name__
174+
172175
def print_impl(self, printer: Printer) -> None:
173176
printer.plain_print("!py.", self.typ.__name__)
174177

@@ -389,6 +392,12 @@ def is_subseteq_Generic(self, other: "Generic") -> bool:
389392
def __hash__(self) -> int:
390393
return hash((Generic, self.body, self.vars, self.vararg))
391394

395+
def __repr__(self) -> str:
396+
if self.vararg is None:
397+
return f"{self.body}[{', '.join(map(repr, self.vars))}]"
398+
else:
399+
return f"{self.body}[{', '.join(map(repr, self.vars))}, {self.vararg}, ...]"
400+
392401
def print_impl(self, printer: Printer) -> None:
393402
printer.print(self.body)
394403
printer.plain_print("[")

src/kirin/prelude.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from kirin.ir import Method, dialect_group
44
from kirin.passes import aggressive
5-
from kirin.dialects import cf, func, math
5+
from kirin.dialects import cf, func, math, ilist
66
from kirin.dialects.py import (
77
cmp,
88
len,
@@ -49,7 +49,7 @@ def run_pass(mt: Method) -> None:
4949
return run_pass
5050

5151

52-
@dialect_group(python_basic.union([list, range, slice]))
52+
@dialect_group(python_basic.union([list, range, slice, cf, func, math]))
5353
def python_no_opt(self):
5454
"""The Python dialect without optimization passes."""
5555

@@ -59,19 +59,27 @@ def run_pass(mt: Method) -> None:
5959
return run_pass
6060

6161

62-
@dialect_group(python_no_opt.union([cf, func, math]))
62+
@dialect_group(python_basic.union([ilist, range, slice, cf, func, math]))
6363
def basic_no_opt(self):
64-
"""The basic kernel without optimization passes.
64+
"""The basic kernel without optimization passes. This is a builtin
65+
eDSL that includes the basic dialects that are commonly used in
66+
Python-like eDSLs.
6567
6668
This eDSL includes the basic dialects without any optimization passes.
6769
Other eDSL can usually be built on top of this eDSL by utilizing the
6870
`basic_no_opt.add` method to add more dialects and optimization passes.
6971
72+
Note that unlike Python, list in this eDSL is immutable, and the
73+
`append` method is not available. Use `+` operator to concatenate lists
74+
instead. Immutable list is easier to optimize and reason about.
75+
7076
See also [`basic`][kirin.prelude.basic] for the basic kernel with optimization passes.
77+
See also [`ilist`][kirin.dialects.ilist] for the immutable list dialect.
7178
"""
79+
ilist_desugar = ilist.IListDesugar(self)
7280

7381
def run_pass(mt: Method) -> None:
74-
pass
82+
ilist_desugar.fixpoint(mt)
7583

7684
return run_pass
7785

@@ -99,6 +107,7 @@ def main(x: int) -> int:
99107
```
100108
"""
101109
fold_pass = Fold(self)
110+
ilist_desugar = ilist.IListDesugar(self)
102111
aggressive_fold_pass = aggressive.Fold(self)
103112
typeinfer_pass = TypeInfer(self)
104113

@@ -122,6 +131,8 @@ def run_pass(
122131
if verify:
123132
mt.verify()
124133

134+
ilist_desugar.fixpoint(mt)
135+
125136
if fold:
126137
if aggressive:
127138
aggressive_fold_pass.fixpoint(mt)

test/analysis/dataflow/test_constprop.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin.decl import info, statement
44
from kirin.prelude import basic_no_opt
55
from kirin.analysis import const
6-
from kirin.dialects import py
6+
from kirin.dialects import ilist
77

88

99
class TestLattice:
@@ -279,10 +279,9 @@ def side_effect_true_branch_const(cond: bool):
279279

280280
def test_non_pure_recursion():
281281
@basic_no_opt
282-
def for_loop_append(cntr: int, x: list, n_range: int):
282+
def for_loop_append(cntr: int, x: ilist.IList, n_range: int):
283283
if cntr < n_range:
284-
py.Append(x, cntr) # type: ignore
285-
for_loop_append(cntr + 1, x, n_range)
284+
for_loop_append(cntr + 1, x + [cntr], n_range)
286285

287286
return x
288287

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin.ir import types
22
from kirin.prelude import basic
3+
from kirin.dialects.ilist import IList, IListType
34

45

56
@basic(typeinfer=True)
@@ -8,10 +9,10 @@ def tuple_new(x: int, xs: tuple):
89

910

1011
@basic(typeinfer=True)
11-
def list_new(x: int, xs: list):
12+
def list_new(x: int, xs: IList):
1213
return xs + [1, x]
1314

1415

1516
def test_tuple_add():
1617
assert tuple_new.return_type.is_subseteq(types.Tuple)
17-
assert list_new.return_type.is_subseteq(types.List)
18+
assert list_new.return_type.is_subseteq(IListType)

test/dialects/pystmts/test_getitem.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from typing import Any
2+
13
from kirin.ir import types
24
from kirin.prelude import basic
5+
from kirin.dialects.ilist import IList, IListType
36

47

58
@basic(typeinfer=True)
@@ -43,12 +46,12 @@ def tuple_const_err(xs: tuple[int, float, str]):
4346

4447

4548
@basic(typeinfer=True)
46-
def list_infer(xs: list[int], i: int):
49+
def list_infer(xs: IList[int, Any], i: int):
4750
return xs[i]
4851

4952

5053
@basic(typeinfer=True)
51-
def list_slice(xs: list[int], i: slice):
54+
def list_slice(xs: IList[int, Any], i: slice):
5255
return xs[i]
5356

5457

@@ -73,5 +76,5 @@ def test_getitem_typeinfer():
7376
assert tuple_err.return_type.is_equal(types.Bottom)
7477
assert tuple_const_err.return_type.is_equal(types.Bottom)
7578
assert list_infer.return_type.is_subseteq(types.Int)
76-
assert list_slice.return_type.is_subseteq(types.List[types.Int])
79+
assert list_slice.return_type.is_subseteq(IListType[types.Int])
7780
assert unknown.return_type.is_equal(types.Any)

test/program/py/test_list_append.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# type: ignore
2-
from kirin.prelude import basic
2+
from kirin.prelude import python_no_opt
33
from kirin.dialects import py
44

55

66
def test_list_append():
77

8-
@basic
8+
@python_no_opt
99
def test_append():
1010
x = []
1111
py.Append(x, 1)
@@ -20,7 +20,7 @@ def test_append():
2020

2121

2222
def test_recursive_append():
23-
@basic
23+
@python_no_opt
2424
def for_loop_append(cntr: int, x: list, n_range: int):
2525
if cntr < n_range:
2626
py.Append(x, cntr)

0 commit comments

Comments
 (0)