Skip to content

Commit 8046478

Browse files
authored
Fix Pyright complains from #535 and #536 (#539)
Fix Pyright complains from #535 and #536 (that removed`isinstance` causing performance issue). Casts to correct type so pyright doesn't complain
1 parent f711b84 commit 8046478

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/kirin/dialects/scf/stmts.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ def __init__(
4141
then_body_block = None
4242
else: # then_body.IS_BLOCK:
4343
then_body_block = cast(Block, then_body)
44-
then_body_region = Region(then_body)
44+
then_body_region = cast(Region, then_body)
4545

4646
if else_body is None:
4747
else_body_region = ir.Region()
4848
else_body_block = None
4949
elif else_body.IS_REGION:
5050
else_body_region = cast(Region, else_body)
51-
if not else_body.blocks: # empty region
51+
if not else_body_region.blocks: # empty region
5252
else_body_block = None
53-
elif len(else_body.blocks) == 0:
53+
elif len(else_body_region.blocks) == 0:
5454
else_body_block = None
5555
else:
5656
else_body_block = else_body_region.blocks[0]
@@ -63,6 +63,7 @@ def __init__(
6363
results = ()
6464
if then_body_block is not None:
6565
then_yield = then_body_block.last_stmt
66+
else_body_block = cast(Block, else_body_block)
6667
else_yield = (
6768
else_body_block.last_stmt if else_body_block is not None else None
6869
)

src/kirin/lowering/frame.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .state import State
2222

2323
CallbackFn = Callable[["Frame", SSAValue], SSAValue]
24+
StmtType = TypeVar("StmtType", bound=Statement)
2425

2526

2627
@dataclass
@@ -53,8 +54,6 @@ class Frame(Generic[Stmt]):
5354
def __repr__(self):
5455
return f"Frame({len(self.defs)} defs, {len(self.globals)} globals)"
5556

56-
StmtType = TypeVar("StmtType", bound=Statement)
57-
5857
@overload
5958
def push(self, node: StmtType) -> StmtType: ...
6059

@@ -65,7 +64,7 @@ def push(self, node: StmtType | Block) -> StmtType | Block:
6564
if node.IS_BLOCK:
6665
return self._push_block(cast(Block, node))
6766
elif node.IS_STATEMENT:
68-
return self._push_stmt(cast(Statement, node))
67+
return self._push_stmt(cast(StmtType, node))
6968
else:
7069
raise BuildError(f"Unsupported type {type(node)} in push()")
7170

0 commit comments

Comments
 (0)