Skip to content

Commit dd8f7db

Browse files
authored
backports 511 and 518 (#519)
1 parent 26403af commit dd8f7db

File tree

4 files changed

+41
-4
lines changed

4 files changed

+41
-4
lines changed

src/kirin/dialects/py/constant.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,13 @@ class Constant(ir.Statement, Generic[T]):
3232

3333
# NOTE: we allow py.Constant take data.PyAttr too
3434
def __init__(self, value: T | ir.Data[T]) -> None:
35-
if not isinstance(value, ir.Data):
35+
if isinstance(value, ir.Method):
36+
value = ir.PyAttr(
37+
value, pytype=types.MethodType[list(value.arg_types), value.return_type]
38+
)
39+
elif not isinstance(value, ir.Data):
3640
value = ir.PyAttr(value)
41+
3742
super().__init__(
3843
attributes={"value": value},
3944
result_types=(value.type,),

src/kirin/passes/aggressive/unroll.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ def unsafe_run(self, mt: Method):
3030
result = RewriteResult()
3131
result = Walk(PickIfElse()).rewrite(mt.code).join(result)
3232
result = Walk(ForLoop()).rewrite(mt.code).join(result)
33-
result = self.typeinfer(mt).join(result)
34-
result = self.fold(mt).join(result)
33+
result = self.fold.unsafe_run(mt).join(result)
34+
self.typeinfer.unsafe_run(mt)
3535
return result
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from kirin.passes import TypeInfer
2+
from kirin.prelude import basic_no_opt
3+
4+
5+
def test_always_rewrites():
6+
@basic_no_opt
7+
def unstable(x: int): # type: ignore
8+
y = x + 1
9+
if y > 10:
10+
z = y
11+
else:
12+
z = y + 1.2
13+
return z
14+
15+
result = TypeInfer(dialects=unstable.dialects, no_raise=False).fixpoint(unstable)
16+
assert (
17+
result.has_done_something
18+
) # this will always be true because TypeInfer always rewrites type

test/analysis/dataflow/typeinfer/test_inter_method.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin import types
2-
from kirin.prelude import basic
2+
from kirin.prelude import basic, structural
3+
from kirin.dialects import ilist
34

45

56
@basic
@@ -28,3 +29,16 @@ def test_inter_method_infer():
2829
assert foo.arg_types[0] == types.Int
2930
assert foo.inferred is False
3031
assert foo.return_type is types.Any
32+
33+
34+
def test_method_constant_type_infer():
35+
36+
@structural(typeinfer=True, fold=False)
37+
def _new(qid: int):
38+
return 1
39+
40+
@structural(fold=False, typeinfer=True)
41+
def alloc(n_iter: int):
42+
return ilist.map(_new, ilist.range(n_iter))
43+
44+
assert alloc.return_type.is_subseteq(ilist.IListType[types.Literal(1), types.Any])

0 commit comments

Comments
 (0)