Skip to content

Commit 163f625

Browse files
kaihsinRoger-luo
andauthored
fcf.Map Inline (#130)
Co-authored-by: Xiu-zhe (Roger) Luo <[email protected]>
1 parent bf5052d commit 163f625

File tree

8 files changed

+109
-6
lines changed

8 files changed

+109
-6
lines changed

src/kirin/dialects/fcf/interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def foldr(self, interp: Interpreter, frame: Frame, stmt: Foldr):
3636
return (acc,)
3737

3838
@impl(Map)
39-
def map_list(self, interp: Interpreter, frame: Frame, stmt: Map):
39+
def map_tuple(self, interp: Interpreter, frame: Frame, stmt: Map):
4040
fn: ir.Method = frame.get(stmt.fn)
4141
coll = frame.get(stmt.coll)
4242
ret = []
@@ -47,7 +47,7 @@ def map_list(self, interp: Interpreter, frame: Frame, stmt: Map):
4747
return _ret
4848
else:
4949
ret.append(_ret)
50-
return (ret,)
50+
return (tuple(ret),)
5151

5252
@impl(Scan)
5353
def scan(self, interp: Interpreter, frame: Frame, stmt: Scan):

src/kirin/dialects/fcf/rewrite/__init__.py

Whitespace-only changes.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Dict
2+
from dataclasses import dataclass
3+
4+
from kirin import ir
5+
from kirin.analysis import const
6+
from kirin.dialects import py, fcf, func
7+
from kirin.rewrite.abc import RewriteRule, RewriteResult
8+
from kirin.ir.nodes.stmt import Statement
9+
10+
11+
@dataclass
12+
class InlineFcfMap(RewriteRule):
13+
cp_results: Dict[ir.SSAValue, const.JointResult]
14+
15+
def rewrite_Statement(self, node: Statement) -> RewriteResult:
16+
match node:
17+
case fcf.Map():
18+
return self.rewrite_fcf_map(node)
19+
case _:
20+
return RewriteResult()
21+
22+
def rewrite_fcf_map(self, node: fcf.Map) -> RewriteResult:
23+
# TODO make this more generic without the need for the constprop results
24+
tmp = self.cp_results.get(node.coll, None)
25+
26+
if (tmp is None) or (not isinstance(tmp.const, const.Value)):
27+
return RewriteResult()
28+
29+
coll = tmp.const.data
30+
31+
# rewrite to directly inline:
32+
# get the method:
33+
tpl_elem = []
34+
curr = node
35+
for i in coll:
36+
new_c = py.Constant(value=i)
37+
newstmt = func.Call(callee=node.fn, inputs=(new_c.result,), kwargs=())
38+
39+
new_c.insert_after(curr)
40+
newstmt.insert_after(new_c)
41+
tpl_elem.append(newstmt.result)
42+
curr = newstmt
43+
44+
# assemble tuple:
45+
tpl = py.tuple.New(values=tuple(tpl_elem))
46+
tpl.insert_after(curr)
47+
node.result.replace_by(tpl.result)
48+
node.delete()
49+
50+
return RewriteResult(has_done_something=True)

src/kirin/dialects/fcf/stmts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def main():
4646
"""The kernel function to apply. The function should have signature `fn(x: int) -> Any`."""
4747
coll: ir.SSAValue = info.argument(ir.types.Any)
4848
"""The iterable to map over."""
49-
result: ir.ResultValue = info.result(ir.types.List)
49+
result: ir.ResultValue = info.result(ir.types.Tuple)
5050
"""The list of results."""
5151

5252

test/dialects/fcf/test_fold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ def test_fold():
4343
assert foldl(xs) == sum(xs)
4444
assert foldr(xs) == sum(xs)
4545
assert map_list.return_type.is_subseteq(types.List[types.Float])
46-
assert map_list([1, 2, 3]) == [2.0, 3.0, 4.0]
46+
assert map_list([1, 2, 3]) == (2.0, 3.0, 4.0)
4747
assert scan([1, 2, 3, 4, 5]) == (15, [1, 2, 3, 4, 5])
4848
assert scan.return_type.is_subseteq(types.Tuple[types.Int, types.List[types.Int]])

test/dialects/fcf/test_map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def map_func(i):
1111

1212

1313
def test_enumerate_kirin():
14-
assert enumerate_kirin([1, 2, 3, 4, 5]) == [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
14+
assert enumerate_kirin([1, 2, 3, 4, 5]) == ((0, 1), (1, 2), (2, 3), (3, 4), (4, 5))

test/dialects/pystmts/test_range.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def map_range(x: range):
4444

4545
def test_map_range():
4646
assert map_range.return_type.is_subseteq(types.List[types.Float])
47-
assert map_range(range(10)) == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
47+
assert map_range(range(10)) == (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)

test/rules/test_fcfmap_inline.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from kirin.prelude import basic
2+
from kirin.rewrite import Walk
3+
from kirin.analysis import const
4+
from kirin.dialects import fcf
5+
from kirin.analysis.const.prop import Propagate
6+
from kirin.dialects.fcf.rewrite.fcfmap_inline import InlineFcfMap
7+
8+
9+
def test_fcfmap_rewrite():
10+
11+
@basic(fold=False)
12+
def fcf_map_rewrite():
13+
14+
def _simple(i: int):
15+
return i
16+
17+
tmp = fcf.Map(_simple, range(5))
18+
return tmp
19+
20+
fcf_map_rewrite.code.print()
21+
cp = Propagate(dialects=fcf_map_rewrite.dialects)
22+
cp.eval(fcf_map_rewrite, ())
23+
Walk(InlineFcfMap(cp.results)).rewrite(fcf_map_rewrite.code)
24+
fcf_map_rewrite.code.print()
25+
26+
val = fcf_map_rewrite()
27+
28+
assert val == (0, 1, 2, 3, 4)
29+
30+
31+
def test_fcfmap_rewrite_with_arg():
32+
33+
@basic(fold=False)
34+
def fcf_map_rewrite_with_arg(x: int):
35+
36+
def _simple(i: int):
37+
return i
38+
39+
tmp = fcf.Map(_simple, (x, x + 1, x + 2))
40+
return tmp
41+
42+
fcf_map_rewrite_with_arg.code.print()
43+
cp = Propagate(dialects=fcf_map_rewrite_with_arg.dialects)
44+
cp.eval(fcf_map_rewrite_with_arg, (const.JointResult.top(),))
45+
Walk(InlineFcfMap(cp.results)).rewrite(fcf_map_rewrite_with_arg.code)
46+
fcf_map_rewrite_with_arg.code.print()
47+
48+
val = fcf_map_rewrite_with_arg(4)
49+
print(val)
50+
assert val == (4, 5, 6)
51+
52+
53+
test_fcfmap_rewrite_with_arg()

0 commit comments

Comments
 (0)