Skip to content

Commit be7ac95

Browse files
authored
rework error handling in compiler (#377)
this PR should be considered non-breaking and let's try to backport it to 0.16 and 0.17 because it recovers the original error behaviour when an error was raised inside the validation or interpreter.
1 parent 0c156be commit be7ac95

File tree

22 files changed

+289
-373
lines changed

22 files changed

+289
-373
lines changed

src/kirin/dialects/py/assertion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def assert_stmt(
6666
return ()
6767

6868
if stmt.message:
69-
raise interp.WrapException(AssertionError(frame.get(stmt.message)))
69+
raise AssertionError(frame.get(stmt.message))
7070
else:
71-
raise interp.WrapException(AssertionError("Assertion failed"))
71+
raise AssertionError("Assertion failed")
7272

7373

7474
@dialect.register(key="typeinfer")

src/kirin/dialects/py/assign.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,7 @@ def type_assert(self, interp_, frame: interp.Frame, stmt: TypeAssert):
9999
got = frame.get(stmt.got)
100100
got_type = types.PyClass(type(got))
101101
if not got_type.is_subseteq(stmt.expected):
102-
raise interp.WrapException(
103-
TypeError(f"Expected {stmt.expected}, got {got_type}")
104-
)
102+
raise TypeError(f"Expected {stmt.expected}, got {got_type}")
105103
return (frame.get(stmt.got),)
106104

107105

src/kirin/exception.py

Lines changed: 154 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,176 @@
22
for the Kirin-based compilers.
33
"""
44

5-
import abc
5+
from __future__ import annotations
6+
7+
import os
68
import sys
9+
import math
710
import types
8-
9-
stacktrace = False
10-
11-
12-
class NoPythonStackTrace(Exception):
13-
pass
14-
15-
16-
class CustomStackTrace(Exception):
17-
18-
@abc.abstractmethod
19-
def print_stacktrace(self) -> None: ...
11+
import shutil
12+
import textwrap
13+
from typing import TYPE_CHECKING
14+
15+
from rich.console import Console
16+
17+
if TYPE_CHECKING:
18+
from kirin import interp
19+
from kirin.source import SourceInfo
20+
21+
KIRIN_INTERP_STATE = "__kirin_interp_state"
22+
KIRIN_PYTHON_STACKTRACE = os.environ.get("KIRIN_PYTHON_STACKTRACE", "0") == "1"
23+
KIRIN_STATIC_CHECK_LINENO = os.environ.get("KIRIN_STATIC_CHECK_LINENO", "1") == "1"
24+
KIRIN_STATIC_CHECK_INDENT = int(os.environ.get("KIRIN_STATIC_CHECK_INDENT", "2"))
25+
KIRIN_STATIC_CHECK_MAX_LINES = int(os.environ.get("KIRIN_STATIC_CHECK_MAX_LINES", "3"))
26+
27+
28+
class StaticCheckError(Exception):
29+
def __init__(self, *messages: str, help: str | None = None) -> None:
30+
super().__init__(*messages)
31+
self.help: str | None = help
32+
self.source: SourceInfo | None = None
33+
self.lines: list[str] | None = None
34+
self.indent: int = KIRIN_STATIC_CHECK_INDENT
35+
self.max_lines: int = KIRIN_STATIC_CHECK_MAX_LINES
36+
self.show_lineno: bool = KIRIN_STATIC_CHECK_LINENO
37+
38+
def hint(self):
39+
help = self.help or ""
40+
source = self.source or SourceInfo(0, 0, 0, 0)
41+
lines = self.lines or []
42+
begin = max(0, source.lineno - self.max_lines)
43+
end = max(
44+
max(source.lineno + self.max_lines, source.end_lineno or 1),
45+
1,
46+
)
47+
end = min(len(lines), end) # make sure end is within bounds
48+
lines = lines[begin:end]
49+
error_lineno = source.lineno + source.lineno_begin
50+
error_lineno_len = len(str(error_lineno))
51+
code_indent = min(map(self.__get_indent, lines), default=0)
52+
53+
console = Console(force_terminal=True)
54+
with console.capture() as capture:
55+
console.print(
56+
f" {source or 'stdin'}",
57+
markup=True,
58+
highlight=False,
59+
)
60+
for lineno, line in enumerate(lines, begin):
61+
line = " " * self.indent + line[code_indent:]
62+
if self.show_lineno:
63+
if lineno + 1 == source.lineno:
64+
line = f"{error_lineno}[dim]│[/dim]" + line
65+
else:
66+
line = "[dim]" + " " * (error_lineno_len) + "│[/dim]" + line
67+
console.print(" " + line, markup=True, highlight=False)
68+
if lineno + 1 == source.lineno:
69+
console.print(
70+
" "
71+
+ self.__arrow(
72+
source,
73+
code_indent,
74+
error_lineno_len,
75+
help,
76+
self.indent,
77+
self.show_lineno,
78+
),
79+
markup=True,
80+
highlight=False,
81+
)
82+
return capture.get()
83+
84+
def __arrow(
85+
self,
86+
source: SourceInfo,
87+
code_indent: int,
88+
error_lineno_len: int,
89+
help,
90+
indent: int,
91+
show_lineno: bool,
92+
) -> str:
93+
ret = " " * (source.col_offset - code_indent)
94+
if source.end_col_offset:
95+
ret += "^" * (source.end_col_offset - source.col_offset)
96+
else:
97+
ret += "^"
98+
99+
ret = " " * indent + "[red]" + ret
100+
if help:
101+
hint_indent = len(ret) - len("[ret]") + len(" help: ")
102+
terminal_width = math.floor(shutil.get_terminal_size().columns * 0.7)
103+
terminal_width = max(terminal_width - hint_indent, 10)
104+
wrapped = textwrap.fill(str(help), width=terminal_width)
105+
lines = wrapped.splitlines()
106+
ret += " help: " + lines[0] + "[/red]"
107+
for line in lines[1:]:
108+
ret += (
109+
"\n"
110+
+ " " * (error_lineno_len + indent)
111+
+ "[dim]│[/dim]"
112+
+ " " * hint_indent
113+
+ "[red]"
114+
+ line
115+
+ "[/red]"
116+
)
117+
if show_lineno:
118+
ret = " " * error_lineno_len + "[dim]│[/dim]" + ret
119+
return ret
120+
121+
@staticmethod
122+
def __get_indent(line: str) -> int:
123+
if len(line) == 0:
124+
return int(1e9) # very large number
125+
return len(line) - len(line.lstrip())
20126

21127

22128
def enable_stracetrace():
23129
"""Enable the stacktrace for all exceptions."""
24-
global stacktrace
25-
stacktrace = True
130+
global KIRIN_PYTHON_STACKTRACE
131+
KIRIN_PYTHON_STACKTRACE = True
26132

27133

28134
def disable_stracetrace():
29135
"""Disable the stacktrace for all exceptions."""
30-
global stacktrace
31-
stacktrace = False
136+
global KIRIN_PYTHON_STACKTRACE
137+
KIRIN_PYTHON_STACKTRACE = False
138+
139+
140+
def print_stacktrace(exception: Exception, state: interp.InterpreterState):
141+
frame: interp.FrameABC | None = state.current_frame
142+
print(
143+
"==== Python stacktrace has been disabled for simplicity, set KIRIN_PYTHON_STACKTRACE=1 to enable it ===="
144+
)
145+
print(f"{type(exception).__name__}: {exception}", file=sys.stderr)
146+
print("Traceback (most recent call last):", file=sys.stderr)
147+
frames: list[interp.FrameABC] = []
148+
while frame is not None:
149+
frames.append(frame)
150+
frame = frame.parent
151+
frames.reverse()
152+
for frame in frames:
153+
if stmt := frame.current_stmt:
154+
print(" " + repr(stmt.source), file=sys.stderr)
155+
print(" " + stmt.print_str(end=""), file=sys.stderr)
32156

33157

34158
def exception_handler(exc_type, exc_value, exc_tb: types.TracebackType):
35159
"""Custom exception handler to format and print exceptions."""
36-
if not stacktrace and issubclass(exc_type, NoPythonStackTrace):
37-
print("".join(msg for msg in exc_value.args), file=sys.stderr)
160+
if not KIRIN_PYTHON_STACKTRACE and issubclass(exc_type, StaticCheckError):
161+
console = Console(force_terminal=True)
162+
with console.capture() as capture:
163+
console.print(f"[bold red]{exc_type.__name__}:[/bold red]", end="")
164+
print(capture.get(), *exc_value.args, file=sys.stderr)
165+
print("Source Traceback:", file=sys.stderr)
166+
print(exc_value.hint(), file=sys.stderr, end="")
38167
return
39168

40-
if not stacktrace and issubclass(exc_type, CustomStackTrace):
169+
if (
170+
not KIRIN_PYTHON_STACKTRACE
171+
and (state := getattr(exc_value, KIRIN_INTERP_STATE, None)) is not None
172+
):
41173
# Handle custom stack trace exceptions
42-
exc_value.print_stacktrace()
174+
print_stacktrace(exc_value, state)
43175
return
44176

45177
# Call the default exception handler
@@ -51,7 +183,7 @@ def exception_handler(exc_type, exc_value, exc_tb: types.TracebackType):
51183

52184

53185
def custom_exc(shell, etype, evalue, tb, tb_offset=None):
54-
if issubclass(etype, NoPythonStackTrace):
186+
if issubclass(etype, StaticCheckError):
55187
# Handle BuildError exceptions
56188
print(evalue, file=sys.stderr)
57189
return

src/kirin/interp/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .base import BaseInterpreter as BaseInterpreter
1919
from .impl import ImplDef as ImplDef, Signature as Signature, impl as impl
2020
from .frame import Frame as Frame, FrameABC as FrameABC
21+
from .state import InterpreterState as InterpreterState
2122
from .table import MethodTable as MethodTable
2223
from .value import (
2324
Successor as Successor,
@@ -32,8 +33,6 @@
3233
)
3334
from .concrete import Interpreter as Interpreter
3435
from .exceptions import (
35-
WrapException as WrapException,
36-
IntepreterExit as IntepreterExit,
3736
InterpreterError as InterpreterError,
3837
FuelExhaustedError as FuelExhaustedError,
3938
)

src/kirin/interp/base.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
from typing_extensions import Self, deprecated
1818

19-
from kirin.ir import Block, Region, Statement, DialectGroup, traits
20-
from kirin.ir.method import Method
19+
from kirin.ir import Block, Method, Region, Statement, DialectGroup, traits
20+
from kirin.exception import KIRIN_INTERP_STATE
2121

2222
from .impl import Signature
2323
from .frame import FrameABC
2424
from .state import InterpreterState
2525
from .value import Successor, ReturnValue, SpecialValue, StatementResult
26-
from .exceptions import IntepreterExit, InterpreterError
26+
from .exceptions import InterpreterError
2727

2828
if TYPE_CHECKING:
2929
from kirin.registry import StatementImpl, InterpreterRegistry
@@ -155,9 +155,10 @@ def run(
155155
try:
156156
_, results = self.run_method(mt, args)
157157
except Exception as e:
158-
# NOTE: initialize will create new State
159-
# so we don't need to copy the frames.
160-
raise IntepreterExit(e, self.state) from e
158+
# NOTE: insert the interpreter state into the exception
159+
# so we can print the stack trace
160+
setattr(e, KIRIN_INTERP_STATE, self.state)
161+
raise e
161162
finally:
162163
self._eval_lock = False
163164
sys.setrecursionlimit(current_recursion_limit)
@@ -175,12 +176,10 @@ def run_stmt(
175176
Returns:
176177
StatementResult[ValueType]: the result of the statement.
177178
"""
178-
frame = self.initialize_frame(stmt)
179-
self.state.push_frame(frame)
180-
frame.set_values(stmt.args, args)
181-
results = self.eval_stmt(frame, stmt)
182-
self.state.pop_frame()
183-
return results
179+
with self.new_frame(stmt) as frame:
180+
frame.set_values(stmt.args, args)
181+
results = self.eval_stmt(frame, stmt)
182+
return results
184183

185184
@abstractmethod
186185
def run_method(

src/kirin/interp/exceptions.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,16 @@
11
from __future__ import annotations
22

3-
import sys
4-
from typing import TYPE_CHECKING
5-
from dataclasses import dataclass
6-
7-
from kirin.exception import CustomStackTrace
8-
9-
if TYPE_CHECKING:
10-
from .frame import FrameABC
11-
from .state import InterpreterState
12-
133

144
# errors
155
class InterpreterError(Exception):
166
"""Generic interpreter error.
177
18-
This is the base class for all interpreter errors. Interpreter
19-
errors will be catched by the interpreter and handled appropriately
20-
as an error with stack trace (of Kirin, not Python) from the interpreter.
8+
This is the base class for all interpreter errors.
219
"""
2210

2311
pass
2412

2513

26-
@dataclass
27-
class WrapException(InterpreterError):
28-
"""A special interpreter error that wraps a Python exception."""
29-
30-
exception: Exception
31-
32-
33-
@dataclass
34-
class IntepreterExit(CustomStackTrace):
35-
exception: Exception
36-
state: InterpreterState
37-
38-
def print_stacktrace(self) -> None:
39-
"""Print the stacktrace of the interpreter."""
40-
frame: FrameABC | None = self.state.current_frame
41-
print(f"{type(self.exception).__name__}: {self.exception}", file=sys.stderr)
42-
print("Traceback (most recent call last):", file=sys.stderr)
43-
frames: list[FrameABC] = []
44-
while frame is not None:
45-
frames.append(frame)
46-
frame = frame.parent
47-
frames.reverse()
48-
for frame in frames:
49-
if stmt := frame.current_stmt:
50-
print(" " + repr(stmt.source), file=sys.stderr)
51-
print(" " + stmt.print_str(end=""), file=sys.stderr)
52-
53-
5414
class FuelExhaustedError(InterpreterError):
5515
"""An error raised when the interpreter runs out of fuel."""
5616

src/kirin/ir/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
from kirin.ir.attrs.py import PyAttr as PyAttr
4444
from kirin.ir.attrs.abc import Attribute as Attribute, AttributeMeta as AttributeMeta
4545
from kirin.ir.exception import (
46-
HintedError as HintedError,
4746
CompilerError as CompilerError,
4847
TypeCheckError as TypeCheckError,
4948
ValidationError as ValidationError,

0 commit comments

Comments
 (0)