Skip to content

Commit a664b05

Browse files
rafaelhacduck
andauthored
Simple O(1) inlining rewrite (#552)
The inline pass (as part of a `Walk` rewrite) has O(n^2) time complexity, which is a performance issue. The extra factor of `n` comes from `inline.py`, where each block is split in two and the inlined region is inserted in the middle. Splitting a block in two costs O(n), because every statement after the call site has to be moved into a new block:

```python
after_block = ir.Block()
stmt = call_like.next_stmt
while stmt is not None:
    stmt.detach()
    after_block.stmts.append(stmt)
    stmt = call_like.next_stmt
```

This PR introduces a partial workaround. Simple regions with just a single block are inlined by inserting all of their statements directly before the call site. Since statements form a linked list, each insertion is O(1), so the parent block no longer needs to be split.

For my test case, I observe that this fix reduces runtime and brings the time complexity of the inline pass back to O(n).

<img width="489" height="358" alt="image" src="https://github.com/user-attachments/assets/3b11560f-670f-45d3-84b0-0959693f47b4" />

However, we should refactor the inline pass to scale linearly even in the general (multi-block) case.

---------

Co-authored-by: Casey Duckering <[email protected]>
1 parent f5bee94 commit a664b05

File tree

1 file changed

+118
-22
lines changed

1 file changed

+118
-22
lines changed

src/kirin/rewrite/inline.py

Lines changed: 118 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable
1+
from typing import Callable, cast
22
from dataclasses import dataclass
33

44
from kirin import ir
@@ -68,6 +68,121 @@ def inline_call_like(
6868
args (tuple[ir.SSAValue, ...]): the arguments of the call (first one is the callee)
6969
region (ir.Region): the region of the callee
7070
"""
71+
if not call_like.parent_block:
72+
return
73+
74+
if not call_like.parent_region:
75+
return
76+
77+
# NOTE: we cannot change region because it may be used elsewhere
78+
inline_region: ir.Region = region.clone()
79+
80+
# Preserve source information by attributing inlined code to the call site
81+
if call_like.source is not None:
82+
for block in inline_region.blocks:
83+
if block.source is None:
84+
block.source = call_like.source
85+
for stmt in block.stmts:
86+
if stmt.source is None:
87+
stmt.source = call_like.source
88+
89+
if self._can_use_simple_inline(inline_region):
90+
return self._inline_simple(call_like, args, inline_region.blocks[0])
91+
92+
return self._inline_complex(call_like, args, inline_region)
93+
94+
def _can_use_simple_inline(self, inline_region: ir.Region) -> bool:
    """Decide whether the single-block fast path applies to a region.

    Args:
        inline_region: The cloned region to be inlined.

    Returns:
        True only when the region holds exactly one block and the last
        statement of that block is a ``func.Return``; False otherwise.
    """
    blocks = inline_region.blocks
    # Short-circuit keeps `blocks[0]` from being evaluated unless there is
    # exactly one block.
    return len(blocks) == 1 and isinstance(blocks[0].last_stmt, func.Return)
113+
114+
def _inline_simple(
    self,
    call_like: ir.Statement,
    args: tuple[ir.SSAValue, ...],
    func_block: ir.Block,
):
    """Fast path: inline a single-block callee directly at the call site.

    Every statement of ``func_block`` that precedes its ``func.Return`` is
    rebuilt with remapped operands and inserted immediately before
    ``call_like``. The results of ``call_like`` are then replaced by the
    (remapped) returned value and ``call_like`` is deleted. The parent
    block of the call is not split.

    The loop visits each statement of ``func_block`` once; the cost of
    ``from_stmt`` and of cloning nested regions is not visible from here.

    Args:
        call_like: The call statement to replace.
        args: Arguments to the call (first is callee, rest are parameters);
            paired positionally with ``func_block.args``.
        func_block: The single block from the cloned function region.
    """
    # Maps values defined inside the callee to their call-site equivalents.
    replacements: dict[ir.SSAValue, ir.SSAValue] = {}

    def resolve(value):
        # Values without an entry (defined outside the callee) pass through.
        return replacements.get(value, value)

    for param, actual in zip(func_block.args, args):
        replacements[param] = actual
        # Give an anonymous call-site value the callee's parameter name.
        if param.name and actual.name is None:
            actual.name = param.name

    for callee_stmt in func_block.stmts:
        if isinstance(callee_stmt, func.Return):
            returned = resolve(callee_stmt.value)
            # Iterating an empty result list is a no-op, so no guard needed.
            for result in call_like.results:
                result.replace_by(returned)
            # The return itself is never emitted into the caller.
            break

        copied = callee_stmt.from_stmt(
            callee_stmt,
            args=[resolve(operand) for operand in callee_stmt.args],
            regions=[nested.clone(replacements) for nested in callee_stmt.regions],
            successors=callee_stmt.successors,  # successors are empty for simple stmts
        )
        copied.insert_before(call_like)

        # Later statements must refer to the freshly created results.
        for before, after in zip(callee_stmt.results, copied.results):
            replacements[before] = after
            if before.name:
                after.name = before.name

    call_like.delete()
166+
167+
def _inline_complex(
168+
self,
169+
call_like: ir.Statement,
170+
args: tuple[ir.SSAValue, ...],
171+
inline_region: ir.Region,
172+
):
173+
"""Inline multi-block function with control flow.
174+
175+
This handles the general case where the function has multiple blocks
176+
177+
Complexity: O(n+k) where n = statements after call site (due to moving them)
178+
and k = number of statements in function.
179+
180+
Args:
181+
call_like: The call statement to replace
182+
args: Arguments to the call
183+
inline_region: The cloned function region to inline
184+
"""
185+
71186
# <stmt>
72187
# <stmt>
73188
# <br (a, b, c)>
@@ -89,26 +204,8 @@ def inline_call_like(
89204
# split the current block into two, and replace the return with
90205
# the branch instruction
91206
# 4. remove the call
92-
if not call_like.parent_block:
93-
return
94-
95-
if not call_like.parent_region:
96-
return
97-
98-
# NOTE: we cannot change region because it may be used elsewhere
99-
inline_region: ir.Region = region.clone()
100-
101-
# Preserve source information by attributing inlined code to the call site
102-
if call_like.source is not None:
103-
for block in inline_region.blocks:
104-
if block.source is None:
105-
block.source = call_like.source
106-
for stmt in block.stmts:
107-
if stmt.source is None:
108-
stmt.source = call_like.source
109-
110-
parent_block: ir.Block = call_like.parent_block
111-
parent_region: ir.Region = call_like.parent_region
207+
parent_block: ir.Block = cast(ir.Block, call_like.parent_block)
208+
parent_region: ir.Region = cast(ir.Region, call_like.parent_region)
112209

113210
# wrap what's after invoke into a block
114211
after_block = ir.Block()
@@ -150,4 +247,3 @@ def inline_call_like(
150247
successor=entry_block,
151248
).insert_before(call_like)
152249
call_like.delete()
153-
return

0 commit comments

Comments
 (0)