Skip to content

Commit 94e86b7

Browse files
neelay893Neelay FruitwalaNeelay Fruitwala
authored
patch aggressive inline for v0.17 (#513)
Fix #508 (this is fixed in main but needs a patch for release-0-17). Add logic in the `Inline` rewrite rule to permute kwargs. --------- Co-authored-by: Neelay Fruitwala <[email protected]> Co-authored-by: Neelay Fruitwala <[email protected]>
1 parent 2248425 commit 94e86b7

File tree

2 files changed

+102
-3
lines changed

2 files changed

+102
-3
lines changed

src/kirin/rewrite/inline.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import dataclass
33

44
from kirin import ir
5+
from kirin.interp import BaseInterpreter
56
from kirin.dialects import cf, func
67
from kirin.rewrite.abc import RewriteRule, RewriteResult
78

@@ -30,7 +31,13 @@ def rewrite_func_Call(self, node: func.Call) -> RewriteResult:
3031
return RewriteResult()
3132

3233
# NOTE: a lambda statement is defined and used in the same scope
33-
self.inline_call_like(node, tuple(node.args), lambda_stmt.body)
34+
arg_names = [arg.name for arg in node.callee.owner.body.blocks[0].args]
35+
args = BaseInterpreter.permute_values(
36+
arg_names=arg_names,
37+
values=tuple(node.args[1:]),
38+
kwarg_names=node.kwargs,
39+
)
40+
self.inline_call_like(node, (node.args[0],) + args, lambda_stmt.body)
3441
return RewriteResult(has_done_something=True)
3542

3643
def rewrite_func_Invoke(self, node: func.Invoke) -> RewriteResult:
@@ -47,9 +54,12 @@ def rewrite_func_Invoke(self, node: func.Invoke) -> RewriteResult:
4754
func_self = Constant(node.callee)
4855
func_self.result.name = node.callee.sym_name
4956
func_self.insert_before(node)
50-
self.inline_call_like(
51-
node, (func_self.result,) + tuple(arg for arg in node.args), region
57+
args = BaseInterpreter.permute_values(
58+
arg_names=node.callee.arg_names,
59+
values=tuple(node.args),
60+
kwarg_names=node.kwargs,
5261
)
62+
self.inline_call_like(node, (func_self.result,) + tuple(args), region)
5363
has_done_something = True
5464

5565
return RewriteResult(has_done_something=has_done_something)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from kirin.passes import aggressive
2+
from kirin.prelude import basic
3+
4+
5+
def test_aggressive_inline():
6+
7+
@basic(aggressive=False)
8+
def foo0(arg0, arg1):
9+
return arg0 - arg1
10+
11+
@basic(aggressive=False)
12+
def main_aggressive(arg0):
13+
return foo0(arg1=2, arg0=arg0)
14+
15+
main_aggressive = main_aggressive.similar()
16+
aggressive.Fold(main_aggressive.dialects).fixpoint(main_aggressive)
17+
18+
assert main_aggressive(1) == -1
19+
20+
21+
def test_aggressive_inline_noargs():
22+
23+
@basic(aggressive=False)
24+
def foo1(arg0, arg1):
25+
return arg0 - arg1
26+
27+
@basic(aggressive=True)
28+
def main_aggressive():
29+
return foo1(arg1=2, arg0=1)
30+
31+
assert main_aggressive() == -1
32+
33+
34+
def test_aggressive_inline_pos_args():
35+
36+
@basic(aggressive=False)
37+
def foo2(arg0, arg1):
38+
return arg0 - arg1
39+
40+
@basic(aggressive=True)
41+
def main_aggressive(arg0):
42+
return foo2(arg0, 2)
43+
44+
assert main_aggressive(1) == -1
45+
46+
47+
def test_aggressive_inline_closure():
48+
49+
# @basic(aggressive=False, fold=False, typeinfer=True)
50+
@basic
51+
def main_aggressive(param: int):
52+
def foo3(arg0: int, arg1: int):
53+
return arg0 - arg1 + param
54+
55+
return foo3(arg1=2, arg0=1)
56+
57+
main_aggressive = main_aggressive.similar()
58+
aggressive.Fold(main_aggressive.dialects).fixpoint(main_aggressive)
59+
60+
assert main_aggressive(1) == 0
61+
62+
63+
def test_aggressive_inline_closure_pos_args():
64+
65+
# @basic(aggressive=False, fold=False, typeinfer=True)
66+
@basic
67+
def main_aggressive(param: int):
68+
def foo3(arg0: int, arg1: int):
69+
return arg0 - arg1 + param
70+
71+
return foo3(1, arg1=2)
72+
73+
main_aggressive = main_aggressive.similar()
74+
aggressive.Fold(main_aggressive.dialects).fixpoint(main_aggressive)
75+
76+
assert main_aggressive(1) == 0
77+
78+
79+
def test_aggressive_inline_closure_alias():
80+
@basic(aggressive=True)
81+
def main_aggressive2(param: int):
82+
def foo4(arg0: int, arg1: int):
83+
return arg0 - arg1 + param
84+
85+
alias_foo4 = foo4
86+
87+
return alias_foo4(arg1=2, arg0=1)
88+
89+
assert main_aggressive2(1) == 0

0 commit comments

Comments
 (0)