Skip to content

Commit 7c01f87

Browse files
authored
support stacktrace in interpreter (#374)
```python from kirin.prelude import basic @basic(typeinfer=True) def some_code_will_error(x): """ This function will raise an error. """ return 1 / x @basic(typeinfer=True) def some_func(): return some_code_will_error(0) some_func() ``` <img width="839" alt="image" src="https://github.com/user-attachments/assets/031700a6-87f7-4d72-bc54-a11871c66e64" />
1 parent 85f0e7d commit 7c01f87

File tree

16 files changed

+101
-45
lines changed

16 files changed

+101
-45
lines changed

src/kirin/analysis/forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def run_analysis(
6464
sys.setrecursionlimit(self.max_python_recursion_depth)
6565
try:
6666
frame, ret = self.run_method(method, args)
67-
except interp.InterpreterError as e:
67+
except Exception as e:
6868
# NOTE: initialize will create new State
6969
# so we don't need to copy the frames.
7070
if not no_raise:

src/kirin/exception.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
for the Kirin-based compilers.
33
"""
44

5+
import abc
56
import sys
67
import types
78

@@ -12,6 +13,12 @@ class NoPythonStackTrace(Exception):
1213
pass
1314

1415

16+
class CustomStackTrace(Exception):
17+
18+
@abc.abstractmethod
19+
def print_stacktrace(self) -> None: ...
20+
21+
1522
def enable_stracetrace():
1623
"""Enable the stacktrace for all exceptions."""
1724
global stacktrace
@@ -30,6 +37,11 @@ def exception_handler(exc_type, exc_value, exc_tb: types.TracebackType):
3037
print(exc_value, file=sys.stderr)
3138
return
3239

40+
if not stacktrace and issubclass(exc_type, CustomStackTrace):
41+
# Handle custom stack trace exceptions
42+
exc_value.print_stacktrace()
43+
return
44+
3345
# Call the default exception handler
3446
sys.__excepthook__(exc_type, exc_value, exc_tb)
3547

src/kirin/interp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .concrete import Interpreter as Interpreter
3535
from .exceptions import (
3636
WrapException as WrapException,
37+
IntepreterExit as IntepreterExit,
3738
InterpreterError as InterpreterError,
3839
FuelExhaustedError as FuelExhaustedError,
3940
)

src/kirin/interp/base.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
from .frame import FrameABC
2424
from .state import InterpreterState
2525
from .value import Successor, ReturnValue, SpecialValue, StatementResult
26-
from .result import Ok, Err, Result
27-
from .exceptions import InterpreterError
26+
from .exceptions import IntepreterExit, InterpreterError
2827

2928
if TYPE_CHECKING:
3029
from kirin.registry import StatementImpl, InterpreterRegistry
@@ -132,7 +131,7 @@ def run(
132131
mt: Method,
133132
args: tuple[ValueType, ...],
134133
kwargs: dict[str, ValueType] | None = None,
135-
) -> Result[ValueType]:
134+
) -> ValueType:
136135
"""Run a method. This is the main entry point of the interpreter.
137136
138137
Args:
@@ -155,14 +154,14 @@ def run(
155154
args = self.get_args(mt.arg_names[len(args) + 1 :], args, kwargs)
156155
try:
157156
_, results = self.run_method(mt, args)
158-
except InterpreterError as e:
157+
except Exception as e:
159158
# NOTE: initialize will create new State
160159
# so we don't need to copy the frames.
161-
return Err(e, self.state)
160+
raise IntepreterExit(e, self.state) from e
162161
finally:
163162
self._eval_lock = False
164163
sys.setrecursionlimit(current_recursion_limit)
165-
return Ok(results)
164+
return results
166165

167166
def run_stmt(
168167
self, stmt: Statement, args: tuple[ValueType, ...]

src/kirin/interp/exceptions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import TYPE_CHECKING
15
from dataclasses import dataclass
26

7+
from kirin.exception import CustomStackTrace
8+
9+
if TYPE_CHECKING:
10+
from .frame import FrameABC
11+
from .state import InterpreterState
12+
313

414
# errors
515
class InterpreterError(Exception):
@@ -20,6 +30,27 @@ class WrapException(InterpreterError):
2030
exception: Exception
2131

2232

33+
@dataclass
34+
class IntepreterExit(CustomStackTrace):
35+
exception: Exception
36+
state: InterpreterState
37+
38+
def print_stacktrace(self) -> None:
39+
"""Print the stacktrace of the interpreter."""
40+
frame: FrameABC | None = self.state.current_frame
41+
print(f"{type(self.exception).__name__}: {self.exception}", file=sys.stderr)
42+
print("Traceback (most recent call last):", file=sys.stderr)
43+
frames: list[FrameABC] = []
44+
while frame is not None:
45+
frames.append(frame)
46+
frame = frame.parent
47+
frames.reverse()
48+
for frame in frames:
49+
if stmt := frame.current_stmt:
50+
print(" " + repr(stmt.source), file=sys.stderr)
51+
print(" " + stmt.print_str(end=""), file=sys.stderr)
52+
53+
2354
class FuelExhaustedError(InterpreterError):
2455
"""An error raised when the interpreter runs out of fuel."""
2556

src/kirin/interp/frame.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class FrameABC(ABC, Generic[ValueType]):
3434
has_parent_access: bool = field(default=False, kw_only=True, compare=True)
3535
"""If we have access to the entries of the parent frame."""
3636

37+
lineno_offset: int = field(default=0, kw_only=True, compare=True)
38+
3739
current_stmt: Statement | None = field(
3840
default=None, init=False, compare=False, repr=False
3941
)

src/kirin/ir/group.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,15 @@ def __call__(
190190
Returns:
191191
Method: the method created from the python function.
192192
"""
193+
frame = inspect.currentframe()
193194

194195
def wrapper(py_func: Callable) -> Method:
195196
if py_func.__name__ == "<lambda>":
196197
raise ValueError("Cannot compile lambda functions")
197198

198199
lineno_offset, file = 0, ""
199-
frame = inspect.currentframe()
200-
if frame and frame.f_back is not None and frame.f_back.f_back is not None:
201-
call_site_frame = frame.f_back.f_back
200+
if frame and frame.f_back is not None:
201+
call_site_frame = frame.f_back
202202
if py_func.__name__ in call_site_frame.f_locals:
203203
raise CompilerError(
204204
f"overwriting function definition of `{py_func.__name__}`"

src/kirin/ir/method.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __call__(self, *args: Param.args, **kwargs: Param.kwargs) -> RetType:
5050
raise ValueError("Incorrect number of arguments")
5151
# NOTE: multi-return values will be wrapped in a tuple for Python
5252
interp = Interpreter(self.dialects)
53-
return interp.run(self, args=args, kwargs=kwargs).expect()
53+
return interp.run(self, args=args, kwargs=kwargs)
5454

5555
@property
5656
def args(self):
@@ -141,7 +141,6 @@ def __postprocess_validation_error(self, e: ValidationError):
141141
source.splitlines(),
142142
e,
143143
file=self.file,
144-
lineno_offset=self.lineno_offset,
145144
)
146145
else:
147146
msg += "\nNo source available"

src/kirin/lowering/python/lowering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def run(
112112
source = source or ast.unparse(stmt)
113113
state = State(
114114
self,
115-
source=SourceInfo.from_ast(stmt, lineno_offset, col_offset),
115+
source=SourceInfo.from_ast(stmt, lineno_offset, col_offset, file),
116116
file=file,
117117
lines=source.splitlines(),
118118
lineno_offset=lineno_offset,
@@ -156,7 +156,7 @@ def lower_global(self, state: State[ast.AST], node: ast.AST) -> LoweringABC.Resu
156156
def visit(self, state: State[ast.AST], node: ast.AST) -> Result:
157157
if hasattr(node, "lineno"):
158158
state.source = SourceInfo.from_ast(
159-
node, state.lineno_offset, state.col_offset
159+
node, state.lineno_offset, state.col_offset, state.file
160160
)
161161
name = node.__class__.__name__
162162
if name in self.registry.ast_table:
@@ -169,7 +169,7 @@ def generic_visit(self, state: State[ast.AST], node: ast.AST) -> Result:
169169
def visit_Call(self, state: State[ast.AST], node: ast.Call) -> Result:
170170
if hasattr(node.func, "lineno"):
171171
state.source = SourceInfo.from_ast(
172-
node.func, state.lineno_offset, state.col_offset
172+
node.func, state.lineno_offset, state.col_offset, state.file
173173
)
174174

175175
global_callee_result = state.get_global(node.func, no_raise=True)

src/kirin/lowering/state.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,5 +228,4 @@ def error_hint(
228228
indent=indent,
229229
show_lineno=show_lineno,
230230
max_lines=max_lines,
231-
lineno_offset=self.lineno_offset,
232231
)

0 commit comments

Comments
 (0)