Skip to content

Commit f711b84

Browse files
authored
Type checking for #535 (#536)
1 parent 5fd6d12 commit f711b84

File tree

5 files changed

+45
-26
lines changed

5 files changed

+45
-26
lines changed

src/kirin/dialects/scf/scf2cf.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import cast
12
from dataclasses import field, dataclass
23

34
from kirin import ir
@@ -34,7 +35,10 @@ def get_entr_and_exit_blks(self, node: For | IfElse):
3435
result.replace_by(exit_block.args.append_from(result.type, result.name))
3536

3637
curr_block = node.parent_block
37-
assert curr_block.IS_BLOCK, "Node must be inside a block"
38+
assert (
39+
curr_block is not None and curr_block.IS_BLOCK
40+
), "Node must be inside a block"
41+
curr_block = cast(ir.Block, curr_block)
3842

3943
curr_block.stmts.append(
4044
Branch(arguments=(), successor=(entr_block := ir.Block()))
@@ -47,8 +51,12 @@ def get_curr_blk_info(self, node: For | IfElse) -> tuple[ir.Region, int]:
4751
curr_block = node.parent_block
4852
region = node.parent_region
4953

50-
assert region.IS_REGION, "Node must be inside a region"
51-
assert curr_block.IS_BLOCK, "Node must be inside a block"
54+
assert region is not None and region.IS_REGION, "Node must be inside a region"
55+
region = cast(ir.Region, region)
56+
assert (
57+
curr_block is not None and curr_block.IS_BLOCK
58+
), "Node must be inside a block"
59+
curr_block = cast(ir.Block, curr_block)
5260

5361
block_idx = region._block_idx[curr_block]
5462
return region, block_idx

src/kirin/dialects/scf/stmts.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing import cast
2+
13
from kirin import ir, types
4+
from kirin.ir import Block, Region
25
from kirin.decl import info, statement
36
from kirin.print.printer import Printer
47

@@ -31,31 +34,29 @@ def __init__(
3134
else_body: ir.Region | ir.Block | None = None,
3235
):
3336
if then_body.IS_REGION:
34-
then_body_region = then_body
37+
then_body_region = cast(Region, then_body)
3538
if then_body_region.blocks:
3639
then_body_block = then_body_region.blocks[-1]
3740
else:
3841
then_body_block = None
39-
elif then_body.IS_BLOCK:
40-
then_body_block = then_body
41-
then_body_region = ir.Region(then_body)
42+
else: # then_body.IS_BLOCK:
43+
then_body_block = cast(Block, then_body)
44+
then_body_region = Region(then_body)
4245

43-
if else_body.IS_REGION:
46+
if else_body is None:
47+
else_body_region = ir.Region()
48+
else_body_block = None
49+
elif else_body.IS_REGION:
50+
else_body_region = cast(Region, else_body)
4451
if not else_body.blocks: # empty region
45-
else_body_region = else_body
4652
else_body_block = None
4753
elif len(else_body.blocks) == 0:
48-
else_body_region = else_body
4954
else_body_block = None
5055
else:
51-
else_body_region = else_body
5256
else_body_block = else_body_region.blocks[0]
53-
elif else_body.IS_BLOCK:
54-
else_body_region = ir.Region(else_body)
57+
else: # else_body.IS_BLOCK:
58+
else_body_region = ir.Region(cast(Block, else_body))
5559
else_body_block = else_body
56-
else:
57-
else_body_region = ir.Region()
58-
else_body_block = None
5960

6061
# if either then or else body has yield, we generate results
6162
# we assume if both have yields, they have the same number of results

src/kirin/ir/exception.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import sys
44
import inspect
55
import textwrap
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, cast
77

88
from rich.console import Console
99

1010
from kirin.exception import StaticCheckError
1111
from kirin.print.printer import Printer
1212

1313
if TYPE_CHECKING:
14-
from kirin.ir import IRNode, Method
14+
from kirin.ir import IRNode, Method, Statement
1515

1616

1717
class ValidationError(StaticCheckError):
@@ -39,7 +39,8 @@ def attach(self, method: Method):
3939
map(lambda each_line: " " * 4 + each_line, node_str.splitlines())
4040
)
4141
if self.node.IS_STATEMENT:
42-
dialect = self.node.dialect.name if self.node.dialect else "<no dialect>"
42+
stmt = cast("Statement", self.node)
43+
dialect = stmt.dialect.name if stmt.dialect else "<no dialect>"
4344
self.args += (
4445
"when verifying the following statement",
4546
f" `{dialect}.{type(self.node).__name__}` at\n",

src/kirin/lowering/frame.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Optional, overload
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Generic,
7+
TypeVar,
8+
Callable,
9+
Optional,
10+
cast,
11+
overload,
12+
)
413
from dataclasses import field, dataclass
514

615
from kirin.ir import Block, Region, SSAValue, Statement
@@ -54,9 +63,9 @@ def push(self, node: Block) -> Block: ...
5463

5564
def push(self, node: StmtType | Block) -> StmtType | Block:
5665
if node.IS_BLOCK:
57-
return self._push_block(node)
66+
return self._push_block(cast(Block, node))
5867
elif node.IS_STATEMENT:
59-
return self._push_stmt(node)
68+
return self._push_stmt(cast(Statement, node))
6069
else:
6170
raise BuildError(f"Unsupported type {type(node)} in push()")
6271

src/kirin/rewrite/walk.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable
1+
from typing import Callable, cast
22
from dataclasses import field, dataclass
33

44
from kirin.ir import Block, Region, Statement
@@ -49,11 +49,11 @@ def populate_worklist(self, node: IRNode) -> None:
4949
return
5050

5151
if node.IS_STATEMENT:
52-
self.populate_worklist_Statement(node)
52+
self.populate_worklist_Statement(cast(Statement, node))
5353
elif node.IS_REGION:
54-
self.populate_worklist_Region(node)
54+
self.populate_worklist_Region(cast(Region, node))
5555
elif node.IS_BLOCK:
56-
self.populate_worklist_Block(node)
56+
self.populate_worklist_Block(cast(Block, node))
5757
else:
5858
raise NotImplementedError(f"populate_worklist_{node.__class__.__name__}")
5959

0 commit comments

Comments
 (0)