Skip to content

Commit 5f94e0a

Browse files
rafaelhacduck
authored andcommitted
Simple O(1) inlining rewrite (#552)
The inline pass (as part of a `Walk` rewrite) has O(n^2) time complexity which is a performance issue. An extra factor of `n` is coming from `inline.py` where each block is split in two, and the inline region is inserted in the middle. Splitting blocks in two comes with the extra O(n) factor: ```python after_block = ir.Block() stmt = call_like.next_stmt while stmt is not None: stmt.detach() after_block.stmts.append(stmt) stmt = call_like.next_stmt ``` This PR introduces a partial workaround. Simple regions with just a single block are inlined by inserting all of their statements directly. Since statements form a linked list, this is O(1). For my test case, I observe that this fix reduces runtime and brings the time complexity of the inline pass back to O(n). <img width="489" height="358" alt="image" src="https://github.com/user-attachments/assets/3b11560f-670f-45d3-84b0-0959693f47b4" /> However, we should refactor the inline pass to scale linearly even in the general case. --------- Co-authored-by: Casey Duckering <[email protected]>
1 parent 8f9e1ce commit 5f94e0a

File tree

1 file changed

+118
-22
lines changed

1 file changed

+118
-22
lines changed

src/kirin/rewrite/inline.py

Lines changed: 118 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable
1+
from typing import Callable, cast
22
from dataclasses import dataclass
33

44
from kirin import ir
@@ -78,6 +78,121 @@ def inline_call_like(
7878
args (tuple[ir.SSAValue, ...]): the arguments of the call (first one is the callee)
7979
region (ir.Region): the region of the callee
8080
"""
81+
if not call_like.parent_block:
82+
return
83+
84+
if not call_like.parent_region:
85+
return
86+
87+
# NOTE: we cannot change region because it may be used elsewhere
88+
inline_region: ir.Region = region.clone()
89+
90+
# Preserve source information by attributing inlined code to the call site
91+
if call_like.source is not None:
92+
for block in inline_region.blocks:
93+
if block.source is None:
94+
block.source = call_like.source
95+
for stmt in block.stmts:
96+
if stmt.source is None:
97+
stmt.source = call_like.source
98+
99+
if self._can_use_simple_inline(inline_region):
100+
return self._inline_simple(call_like, args, inline_region.blocks[0])
101+
102+
return self._inline_complex(call_like, args, inline_region)
103+
104+
def _can_use_simple_inline(self, inline_region: ir.Region) -> bool:
105+
"""Check if we can use the fast path for simple single-block inlining.
106+
107+
Args:
108+
inline_region: The cloned region to be inlined
109+
110+
Returns:
111+
True if simple inline is possible (single block with simple return)
112+
"""
113+
if len(inline_region.blocks) != 1:
114+
return False
115+
116+
block = inline_region.blocks[0]
117+
118+
# Last statement must be a simple return
119+
if not isinstance(block.last_stmt, func.Return):
120+
return False
121+
122+
return True
123+
124+
def _inline_simple(
125+
self,
126+
call_like: ir.Statement,
127+
args: tuple[ir.SSAValue, ...],
128+
func_block: ir.Block,
129+
):
130+
"""Fast path: inline single-block function by splicing statements.
131+
132+
For simple functions with no control flow, we just clone the function's
133+
statements and insert them before the call site.
134+
No new blocks are created, no statement parent updates are needed.
135+
136+
Complexity: O(k) where k = number of statements in function (typically small)
137+
138+
Args:
139+
call_like: The call statement to replace
140+
args: Arguments to the call (first is callee, rest are parameters)
141+
func_block: The single block from the cloned function region
142+
"""
143+
ssa_map: dict[ir.SSAValue, ir.SSAValue] = {}
144+
for func_arg, call_arg in zip(func_block.args, args):
145+
ssa_map[func_arg] = call_arg
146+
if func_arg.name and call_arg.name is None:
147+
call_arg.name = func_arg.name
148+
149+
for stmt in func_block.stmts:
150+
if isinstance(stmt, func.Return):
151+
return_value = ssa_map.get(stmt.value, stmt.value)
152+
153+
if call_like.results:
154+
for call_result in call_like.results:
155+
call_result.replace_by(return_value)
156+
157+
# Don't insert the return statement itself
158+
break
159+
160+
new_stmt = stmt.from_stmt(
161+
stmt,
162+
args=[ssa_map.get(arg, arg) for arg in stmt.args],
163+
regions=[r.clone(ssa_map) for r in stmt.regions],
164+
successors=stmt.successors, # successors are empty for simple stmts
165+
)
166+
167+
new_stmt.insert_before(call_like)
168+
169+
# Update SSA mapping for newly created results
170+
for old_result, new_result in zip(stmt.results, new_stmt.results):
171+
ssa_map[old_result] = new_result
172+
if old_result.name:
173+
new_result.name = old_result.name
174+
175+
call_like.delete()
176+
177+
def _inline_complex(
178+
self,
179+
call_like: ir.Statement,
180+
args: tuple[ir.SSAValue, ...],
181+
inline_region: ir.Region,
182+
):
183+
"""Inline multi-block function with control flow.
184+
185+
This handles the general case where the function has multiple blocks
186+
187+
Complexity: O(n+k) where n = statements after call site (due to moving them)
188+
and k = number of statements in function.
189+
190+
Args:
191+
call_like: The call statement to replace
192+
args: Arguments to the call
193+
inline_region: The cloned function region to inline
194+
"""
195+
81196
# <stmt>
82197
# <stmt>
83198
# <br (a, b, c)>
@@ -99,26 +214,8 @@ def inline_call_like(
99214
# split the current block into two, and replace the return with
100215
# the branch instruction
101216
# 4. remove the call
102-
if not call_like.parent_block:
103-
return
104-
105-
if not call_like.parent_region:
106-
return
107-
108-
# NOTE: we cannot change region because it may be used elsewhere
109-
inline_region: ir.Region = region.clone()
110-
111-
# Preserve source information by attributing inlined code to the call site
112-
if call_like.source is not None:
113-
for block in inline_region.blocks:
114-
if block.source is None:
115-
block.source = call_like.source
116-
for stmt in block.stmts:
117-
if stmt.source is None:
118-
stmt.source = call_like.source
119-
120-
parent_block: ir.Block = call_like.parent_block
121-
parent_region: ir.Region = call_like.parent_region
217+
parent_block: ir.Block = cast(ir.Block, call_like.parent_block)
218+
parent_region: ir.Region = cast(ir.Region, call_like.parent_region)
122219

123220
# wrap what's after invoke into a block
124221
after_block = ir.Block()
@@ -160,4 +257,3 @@ def inline_call_like(
160257
successor=entry_block,
161258
).insert_before(call_like)
162259
call_like.delete()
163-
return

0 commit comments

Comments
 (0)