Skip to content

Commit 8270794

Browse files
authored
fix: verify + verify_type was not raising correctly (#383)
Thanks to @david-pl 's bug report (see test) cc: @david-pl @kaihsin
1 parent 25020b8 commit 8270794

File tree

5 files changed

+62
-12
lines changed

5 files changed

+62
-12
lines changed

src/kirin/decl/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from kirin.decl.emit.init import EmitInit
1010
from kirin.decl.emit.name import EmitName
1111
from kirin.decl.emit.repr import EmitRepr
12+
from kirin.decl.emit.check import EmitCheck
1213
from kirin.decl.emit.traits import EmitTraits
13-
from kirin.decl.emit.verify import EmitVerify
1414
from kirin.decl.scan_fields import ScanFields
1515
from kirin.decl.emit.dialect import EmitDialect
1616
from kirin.decl.emit.property import EmitProperty
@@ -26,7 +26,7 @@ class StatementDecl(
2626
EmitName,
2727
EmitRepr,
2828
EmitTraits,
29-
EmitVerify,
29+
EmitCheck,
3030
EmitCheckType,
3131
):
3232
pass
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from ._set_new_attribute import set_new_attribute
88

99

10-
class EmitVerify(BaseModifier):
10+
class EmitCheck(BaseModifier):
1111
_VERIFICATION_ERROR = "_kirin_ValidationError"
1212

13-
def emit_verify(self):
14-
verify_locals: dict[str, Any] = {
13+
def emit_check(self):
14+
check_locals: dict[str, Any] = {
1515
self._VERIFICATION_ERROR: ValidationError,
1616
}
1717
body: list[str] = []
@@ -29,8 +29,8 @@ def emit_verify(self):
2929

3030
if (traits := getattr(self.cls, "traits", None)) is not None:
3131
for trait in traits:
32-
trait_obj = f"_kirin_verify_trait_{trait.__class__.__name__}"
33-
verify_locals.update({trait_obj: trait})
32+
trait_obj = f"_kirin_check_trait_{trait.__class__.__name__}"
33+
check_locals.update({trait_obj: trait})
3434
body.append(f"{trait_obj}.verify({self._self_name})")
3535

3636
# NOTE: we still need to generate this because it is abstract
@@ -39,13 +39,13 @@ def emit_verify(self):
3939

4040
set_new_attribute(
4141
self.cls,
42-
"verify",
42+
"check",
4343
create_fn(
44-
name="_kirin_decl_verify",
44+
name="_kirin_decl_check",
4545
args=[self._self_name],
4646
body=body,
4747
globals=self.globals,
48-
locals=verify_locals,
48+
locals=check_locals,
4949
return_type=None,
5050
),
5151
)

src/kirin/decl/emit/check_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class EmitCheckType(BaseModifier):
11-
_VERIFICATION_ERROR = "_kirin_IRValidationError"
11+
_VERIFICATION_ERROR = "_kirin_ValidationError"
1212

1313
def emit_check_type(self):
1414
check_type_locals: dict[str, Any] = {

src/kirin/exception.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
from rich.console import Console
1616

17+
from kirin.source import SourceInfo
18+
1719
if TYPE_CHECKING:
1820
from kirin import interp
19-
from kirin.source import SourceInfo
2021

2122
KIRIN_INTERP_STATE = "__kirin_interp_state"
2223
KIRIN_PYTHON_STACKTRACE = os.environ.get("KIRIN_PYTHON_STACKTRACE", "0") == "1"

test/ir/test_verify.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import pytest
2+
3+
from kirin import ir, lowering
4+
from kirin.decl import statement
5+
from kirin.prelude import basic_no_opt
6+
7+
dialect = ir.Dialect("foo")
8+
9+
10+
@statement(dialect=dialect)
11+
class InvalidStmt(ir.Statement):
12+
traits = frozenset({lowering.FromPythonCall()})
13+
14+
def check(self):
15+
raise ValueError("Never triggers")
16+
17+
18+
@statement(dialect=dialect)
19+
class InvalidType(ir.Statement):
20+
traits = frozenset({lowering.FromPythonCall()})
21+
22+
def check_type(self):
23+
raise ValueError("Never triggers")
24+
25+
26+
@ir.dialect_group(basic_no_opt.add(dialect))
27+
def foo(self):
28+
def run_pass(mt):
29+
pass
30+
31+
return run_pass
32+
33+
34+
def test_invalid_stmt():
35+
@foo
36+
def test():
37+
InvalidStmt()
38+
39+
with pytest.raises(Exception):
40+
test.verify()
41+
42+
43+
def test_invalid_type():
44+
@foo
45+
def test():
46+
InvalidType()
47+
48+
with pytest.raises(Exception):
49+
test.verify_type()

0 commit comments

Comments
 (0)