Skip to content

Commit 19fc758

Browse files
committed
change return type of map to tuple instead of list
1 parent 84f217c commit 19fc758

File tree

7 files changed

+38
-15
lines changed

7 files changed

+38
-15
lines changed

src/kirin/dialects/fcf/interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def foldr(self, interp: Interpreter, frame: Frame, stmt: Foldr):
3636
return (acc,)
3737

3838
@impl(Map)
39-
def map_list(self, interp: Interpreter, frame: Frame, stmt: Map):
39+
def map_tuple(self, interp: Interpreter, frame: Frame, stmt: Map):
4040
fn: ir.Method = frame.get(stmt.fn)
4141
coll = frame.get(stmt.coll)
4242
ret = []
@@ -47,7 +47,7 @@ def map_list(self, interp: Interpreter, frame: Frame, stmt: Map):
4747
return _ret
4848
else:
4949
ret.append(_ret)
50-
return (ret,)
50+
return (tuple(ret),)
5151

5252
@impl(Scan)
5353
def scan(self, interp: Interpreter, frame: Frame, stmt: Scan):

src/kirin/dialects/fcf/rewrite/fcfmap_inline.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from kirin import ir
55
from kirin.analysis import const
66
from kirin.dialects import fcf, func
7-
from kirin.exceptions import DialectInterpretationError
87
from kirin.dialects.py import stmts as py
98
from kirin.rewrite.abc import RewriteRule, RewriteResult
109
from kirin.ir.nodes.stmt import Statement
@@ -14,13 +13,6 @@
1413
class InlineFcfMap(RewriteRule):
1514
cp_results: Dict[ir.SSAValue, const.JointResult]
1615

17-
def get_const_value(self, p: ir.SSAValue):
18-
tmp = self.cp_results.get(p, None)
19-
if (tmp is None) or (not isinstance(tmp.const, const.Value)):
20-
raise DialectInterpretationError(f"not a const value: {p}")
21-
22-
return tmp.const.data
23-
2416
def rewrite_Statement(self, node: Statement) -> RewriteResult:
2517
match node:
2618
case fcf.Map():
@@ -30,7 +22,12 @@ def rewrite_Statement(self, node: Statement) -> RewriteResult:
3022

3123
def rewrite_fcf_map(self, node: fcf.Map) -> RewriteResult:
3224
# TODO make this more generic without the need for the constprop results
33-
coll = self.get_const_value(node.coll)
25+
tmp = self.cp_results.get(node.coll, None)
26+
27+
if (tmp is None) or (not isinstance(tmp.const, const.Value)):
28+
return RewriteResult()
29+
30+
coll = tmp.const.data
3431

3532
# rewrite to directly inline:
3633
# get the method:

src/kirin/dialects/fcf/stmts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class Foldr(ir.Statement):
2323
class Map(ir.Statement):
2424
fn: ir.SSAValue = info.argument(ir.types.PyClass(ir.Method))
2525
coll: ir.SSAValue = info.argument(ir.types.Any)
26-
result: ir.ResultValue = info.result(ir.types.List)
26+
result: ir.ResultValue = info.result(ir.types.Tuple)
2727

2828

2929
@statement(dialect=dialect)

test/dialects/fcf/test_fold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ def test_fold():
4343
assert foldl(xs) == sum(xs)
4444
assert foldr(xs) == sum(xs)
4545
assert map_list.return_type.is_subseteq(types.List[types.Float])
46-
assert map_list([1, 2, 3]) == [2.0, 3.0, 4.0]
46+
assert map_list([1, 2, 3]) == (2.0, 3.0, 4.0)
4747
assert scan([1, 2, 3, 4, 5]) == (15, [1, 2, 3, 4, 5])
4848
assert scan.return_type.is_subseteq(types.Tuple[types.Int, types.List[types.Int]])

test/dialects/fcf/test_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def map_func(i):
1111

1212

1313
def test_enumerate_kirin():
14-
assert enumerate_kirin([1, 2, 3, 4, 5]) == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
14+
assert enumerate_kirin([1, 2, 3, 4, 5]) == ((0, 1), (1, 2), (2, 3), (3, 4), (4, 5))

test/dialects/pystmts/test_range.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def map_range(x: range):
4444

4545
def test_map_range():
4646
assert map_range.return_type.is_subseteq(types.List[types.Float])
47-
assert map_range(range(10)) == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
47+
assert map_range(range(10)) == (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)

test/rules/test_fcfmap_inline.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin.prelude import basic
22
from kirin.rewrite import Walk
3+
from kirin.analysis import const
34
from kirin.dialects import fcf
45
from kirin.analysis.const.prop import Propagate
56
from kirin.dialects.fcf.rewrite.fcfmap_inline import InlineFcfMap
@@ -25,3 +26,28 @@ def _simple(i: int):
2526
val = fcf_map_rewrite()
2627

2728
assert val == (0, 1, 2, 3, 4)
29+
30+
31+
def test_fcfmap_rewrite_with_arg():
32+
33+
@basic(fold=False)
34+
def fcf_map_rewrite_with_arg(x: int):
35+
36+
def _simple(i: int):
37+
return i
38+
39+
tmp = fcf.Map(_simple, (x, x + 1, x + 2))
40+
return tmp
41+
42+
fcf_map_rewrite_with_arg.code.print()
43+
cp = Propagate(dialects=fcf_map_rewrite_with_arg.dialects)
44+
cp.eval(fcf_map_rewrite_with_arg, (const.JointResult.top(),))
45+
Walk(InlineFcfMap(cp.results)).rewrite(fcf_map_rewrite_with_arg.code)
46+
fcf_map_rewrite_with_arg.code.print()
47+
48+
val = fcf_map_rewrite_with_arg(4)
49+
print(val)
50+
assert val == (4, 5, 6)
51+
52+
53+
test_fcfmap_rewrite_with_arg()

0 commit comments

Comments
 (0)