Skip to content

Commit e7d3379

Browse files
authored
feat: add a builtin option for Pass to not raise (#357)
closes #347 I didn't add it in `unsafe_run` because the `@dataclass` object of pass was designed to store those Pass configurations.
1 parent 4e6d8ca commit e7d3379

File tree

5 files changed

+13
-6
lines changed

5 files changed

+13
-6
lines changed

src/kirin/analysis/typeinfer/analysis.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@ class TypeInference(Forward[types.TypeAttribute]):
2626
lattice = types.TypeAttribute
2727

2828
def run_analysis(
29-
self, method: ir.Method, args: tuple[types.TypeAttribute, ...] | None = None
29+
self,
30+
method: ir.Method,
31+
args: tuple[types.TypeAttribute, ...] | None = None,
32+
*,
33+
no_raise: bool = True,
3034
) -> tuple[ForwardFrame[types.TypeAttribute], types.TypeAttribute]:
3135
if args is None:
3236
args = method.arg_types
33-
return super().run_analysis(method, args)
37+
return super().run_analysis(method, args, no_raise=no_raise)
3438

3539
# NOTE: unlike concrete interpreter, instead of using type information
3640
# within the IR. Type inference will use the interpreted

src/kirin/passes/abc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from typing import ClassVar
3-
from dataclasses import dataclass
3+
from dataclasses import field, dataclass
44

55
from kirin.ir import Method, DialectGroup
66
from kirin.rewrite.abc import RewriteResult
@@ -24,6 +24,7 @@ class Pass(ABC):
2424

2525
name: ClassVar[str]
2626
dialects: DialectGroup
27+
no_raise: bool = field(default=False, kw_only=True)
2728

2829
def __call__(self, mt: Method) -> RewriteResult:
2930
result = self.unsafe_run(mt)

src/kirin/passes/aggressive/fold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __post_init__(self):
2828

2929
def unsafe_run(self, mt: Method) -> RewriteResult:
3030
result = RewriteResult()
31-
frame, _ = self.constprop.run_analysis(mt)
31+
frame, _ = self.constprop.run_analysis(mt, no_raise=self.no_raise)
3232
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
3333
rule = Chain(
3434
ConstantFold(),

src/kirin/passes/fold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Fold(Pass):
2222

2323
def unsafe_run(self, mt: Method) -> RewriteResult:
2424
constprop = const.Propagate(self.dialects)
25-
frame, _ = constprop.run_analysis(mt)
25+
frame, _ = constprop.run_analysis(mt, no_raise=self.no_raise)
2626
result = Walk(WrapConst(frame)).rewrite(mt.code)
2727
result = (
2828
Fixpoint(

src/kirin/passes/typeinfer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ def __post_init__(self):
1717
self.infer = TypeInference(self.dialects)
1818

1919
def unsafe_run(self, mt: Method) -> RewriteResult:
20-
frame, return_type = self.infer.run_analysis(mt, mt.arg_types)
20+
frame, return_type = self.infer.run_analysis(
21+
mt, mt.arg_types, no_raise=self.no_raise
22+
)
2123
if trait := mt.code.get_trait(HasSignature):
2224
trait.set_signature(mt.code, Signature(mt.arg_types, return_type))
2325

0 commit comments

Comments
 (0)