Skip to content

Commit 0b79d67

Browse files
authored
refactor interp to clean up APIs + improve stacktrace (#369)
this PR cleans up the double `new_frame` API in the interpreter framework (closes #363). We also refactor the `run_ssacfg_region` and deprecate `run_block` interface to make linter happier (it was not consistent in `AbstractInterpreter` vs `BaseInterpreter`). We still need some further thinking on what is a value of `region` and `block` (determine by their terminating statement/block?) so that codegen framework can share more codebase with interpreter.
1 parent d7cc18d commit 0b79d67

File tree

19 files changed

+231
-183
lines changed

19 files changed

+231
-183
lines changed

src/kirin/analysis/const/prop.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,18 @@ def initialize(self):
5252
self._interp.initialize()
5353
return self
5454

55-
def new_frame(self, code: ir.Statement) -> Frame:
56-
return Frame.from_func_like(code)
55+
def initialize_frame(
56+
self, code: ir.Statement, *, has_parent_access: bool = False
57+
) -> Frame:
58+
return Frame(code, has_parent_access=has_parent_access)
5759

5860
def try_eval_const_pure(
5961
self,
6062
frame: Frame,
6163
stmt: ir.Statement,
6264
values: tuple[Value, ...],
6365
) -> interp.StatementResult[Result]:
64-
_frame = self._interp.new_frame(frame.code)
66+
_frame = self._interp.initialize_frame(frame.code)
6567
_frame.set_values(stmt.args, tuple(x.data for x in values))
6668
method = self._interp.lookup_registry(frame, stmt)
6769
if method is not None:

src/kirin/analysis/forward.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def run_analysis(
6969
# so we don't need to copy the frames.
7070
if not no_raise:
7171
raise e
72-
return self.new_frame(method.code), self.lattice.bottom()
72+
return self.state.current_frame, self.lattice.bottom()
7373
finally:
7474
self._eval_lock = False
7575
sys.setrecursionlimit(current_recursion_limit)
@@ -103,5 +103,7 @@ class Forward(ForwardExtra[ForwardFrame[LatticeElemType], LatticeElemType], ABC)
103103
[`ForwardExtra`][kirin.analysis.forward.ForwardExtra] instead.
104104
"""
105105

106-
def new_frame(self, code: ir.Statement) -> ForwardFrame[LatticeElemType]:
107-
return ForwardFrame.from_func_like(code)
106+
def initialize_frame(
107+
self, code: ir.Statement, *, has_parent_access: bool = False
108+
) -> ForwardFrame[LatticeElemType]:
109+
return ForwardFrame(code, has_parent_access=has_parent_access)

src/kirin/dialects/cf/constprop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class DialectConstProp(MethodTable):
99

1010
@impl(Branch)
1111
def branch(self, interp: const.Propagate, frame: const.Frame, stmt: Branch):
12-
interp.state.current_frame().worklist.append(
12+
interp.state.current_frame.worklist.append(
1313
Successor(stmt.successor, *frame.get_values(stmt.arguments))
1414
)
1515
return ()
@@ -21,7 +21,7 @@ def conditional_branch(
2121
frame: const.Frame,
2222
stmt: ConditionalBranch,
2323
):
24-
frame = interp.state.current_frame()
24+
frame = interp.state.current_frame
2525
cond = frame.get(stmt.cond)
2626
if isinstance(cond, const.Value):
2727
else_successor = Successor(

src/kirin/dialects/func/emit.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ def emit_function(
1818
self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Function
1919
):
2020
fn_args = stmt.body.blocks[0].args[1:]
21-
argnames = frame.get_values(fn_args)
21+
argnames = tuple(interp.ssa_id[arg] for arg in fn_args)
2222
argtypes = tuple(interp.emit_attribute(x.type) for x in fn_args)
2323
args = [f"{name}::{type}" for name, type in zip(argnames, argtypes)]
2424
interp.write(f"function {stmt.sym_name}({', '.join(args)})")
2525
frame.indent += 1
26-
interp.run_ssacfg_region(frame, stmt.body)
26+
interp.run_ssacfg_region(frame, stmt.body, (stmt.sym_name,) + argnames)
2727
frame.indent -= 1
2828
interp.writeln(frame, "end")
2929
return ()
@@ -63,12 +63,11 @@ def emit_lambda(
6363
self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Lambda
6464
):
6565
args = tuple(interp.ssa_id[x] for x in stmt.body.blocks[0].args[1:])
66-
frame.set_values(stmt.body.blocks[0].args, args)
67-
frame.set_values((stmt.body.blocks[0].args[0],), (stmt.sym_name,))
66+
frame.set_values(stmt.body.blocks[0].args, (stmt.sym_name,) + args)
6867
frame.captured[stmt.body.blocks[0].args[0]] = frame.get_values(stmt.captured)
6968
interp.writeln(frame, f"function {stmt.sym_name}({', '.join(args[1:])})")
7069
frame.indent += 1
71-
interp.run_ssacfg_region(frame, stmt.body)
70+
interp.run_ssacfg_region(frame, stmt.body, args)
7271
frame.indent -= 1
7372
interp.writeln(frame, "end")
7473
return (stmt.sym_name,)

src/kirin/dialects/scf/absint.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def _infer_if_else_cond(
5656
frame.worklist.append(interp.Successor(body_block, frame.get(stmt.cond)))
5757
return
5858

59-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
60-
body_frame.entries.update(frame.entries)
61-
body_frame.set(body_block.args[0], frame.get(stmt.cond))
62-
ret = interp_.run_ssacfg_region(body_frame, body)
59+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
60+
ret = interp_.run_ssacfg_region(body_frame, body, (frame.get(stmt.cond),))
6361
frame.entries.update(body_frame.entries)
64-
return ret
62+
return ret

src/kirin/dialects/scf/constprop.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,9 @@ def _prop_const_cond_ifelse(
8383
cond: const.Value,
8484
body: ir.Region,
8585
):
86-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
87-
body_frame.entries.update(frame.entries)
88-
body_frame.set(body.blocks[0].args[0], cond)
89-
results = interp_.run_ssacfg_region(body_frame, body)
90-
return body_frame, results
86+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
87+
results = interp_.run_ssacfg_region(body_frame, body, (cond,))
88+
return body_frame, results
9189

9290
@interp.impl(For)
9391
def for_loop(
@@ -116,17 +114,12 @@ def _prop_const_iterable_forloop(
116114
)
117115

118116
loop_vars = frame.get_values(stmt.initializers)
119-
body_block = stmt.body.blocks[0]
120-
block_args = body_block.args
121117

122118
for value in iterable.data:
123-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
124-
body_frame.entries.update(frame.entries)
125-
body_frame.set_values(
126-
block_args,
127-
(const.Value(value),) + loop_vars,
119+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
120+
loop_vars = interp_.run_ssacfg_region(
121+
body_frame, stmt.body, (const.Value(value),) + loop_vars
128122
)
129-
loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body)
130123

131124
if body_frame.frame_is_not_pure:
132125
frame_is_not_pure = True

src/kirin/dialects/scf/interp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@ def if_else(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: IfElse
1818
body = stmt.then_body
1919
else:
2020
body = stmt.else_body
21-
return interp_.run_ssacfg_region(frame, body)
21+
return interp_.run_ssacfg_region(frame, body, (cond,))
2222

2323
@interp.impl(For)
2424
def for_loop(self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: For):
2525
iterable = frame.get(stmt.iterable)
2626
loop_vars = frame.get_values(stmt.initializers)
27-
block_args = stmt.body.blocks[0].args
2827
for value in iterable:
29-
frame.set_values(block_args, (value,) + loop_vars)
30-
loop_vars = interpreter.run_ssacfg_region(frame, stmt.body)
28+
loop_vars = interpreter.run_ssacfg_region(
29+
frame, stmt.body, (value,) + loop_vars
30+
)
3131
if isinstance(loop_vars, interp.ReturnValue):
3232
return loop_vars
3333
elif loop_vars is None:

src/kirin/dialects/scf/typeinfer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ def for_loop(
4545
frame.worklist.append(interp.Successor(body_block, item, *loop_vars))
4646
return # if terminate is Return, there is no result
4747

48-
with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
49-
body_frame.entries.update(frame.entries)
50-
loop_vars_ = interp_.run_ssacfg_region(body_frame, stmt.body)
48+
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
49+
loop_vars_ = interp_.run_ssacfg_region(
50+
body_frame, stmt.body, (iterable,) + loop_vars
51+
)
5152

5253
frame.entries.update(body_frame.entries)
5354
if isinstance(loop_vars_, interp.ReturnValue):

src/kirin/emit/abc.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ class EmitFrame(interp.Frame[ValueType]):
2121
class EmitABC(interp.BaseInterpreter[FrameType, ValueType], ABC):
2222

2323
def run_callable_region(
24-
self, frame: FrameType, code: ir.Statement, region: ir.Region
24+
self,
25+
frame: FrameType,
26+
code: ir.Statement,
27+
region: ir.Region,
28+
args: tuple[ValueType, ...],
2529
) -> ValueType:
2630
results = self.eval_stmt(frame, code)
2731
if isinstance(results, tuple):
@@ -32,12 +36,11 @@ def run_callable_region(
3236
raise interp.InterpreterError(f"Unexpected results {results}")
3337

3438
def run_ssacfg_region(
35-
self, frame: FrameType, region: ir.Region
39+
self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...]
3640
) -> tuple[ValueType, ...]:
37-
frame.worklist.append(
38-
interp.Successor(region.blocks[0], frame.get_values(region.blocks[0].args))
39-
)
41+
frame.worklist.append(interp.Successor(region.blocks[0], *args))
4042
while (succ := frame.worklist.pop()) is not None:
43+
frame.set_values(succ.block.args, succ.block_args)
4144
block_header = self.emit_block(frame, succ.block)
4245
frame.block_ref[succ.block] = block_header
4346
return ()

src/kirin/emit/abc.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@ FrameType = TypeVar("FrameType", bound=EmitFrame)
1515

1616
class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
1717
def run_callable_region(
18-
self, frame: FrameType, code: ir.Statement, region: ir.Region
18+
self,
19+
frame: FrameType,
20+
code: ir.Statement,
21+
region: ir.Region,
22+
args: tuple[ValueType, ...],
1923
) -> ValueType: ...
2024
def run_ssacfg_region(
21-
self, frame: FrameType, region: ir.Region
25+
self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...]
2226
) -> tuple[ValueType, ...]: ...
2327
def emit_attribute(self, attr: ir.Attribute) -> ValueType: ...
2428
def emit_type_Any(self, attr: types.AnyType) -> ValueType: ...

0 commit comments

Comments
 (0)