diff --git a/src/kirin/dialects/scf/stmts.py b/src/kirin/dialects/scf/stmts.py index 3c4ea008f..6348b8b5d 100644 --- a/src/kirin/dialects/scf/stmts.py +++ b/src/kirin/dialects/scf/stmts.py @@ -41,7 +41,7 @@ def __init__( then_body_block = None else: # then_body.IS_BLOCK: then_body_block = cast(Block, then_body) - then_body_region = cast(Region, then_body) + then_body_region = Region(then_body_block) if else_body is None: else_body_region = ir.Region() diff --git a/test/dialects/scf/test_ifelse.py b/test/dialects/scf/test_ifelse.py index c5cdcf906..86f35dbda 100644 --- a/test/dialects/scf/test_ifelse.py +++ b/test/dialects/scf/test_ifelse.py @@ -3,7 +3,7 @@ from kirin import ir from kirin.passes import Fold from kirin.prelude import python_basic -from kirin.dialects import scf, func, lowering +from kirin.dialects import py, scf, func, lowering # TODO: # test_cons @@ -150,3 +150,19 @@ def main(n: int): assert main(2) == 1.0 assert main(1) == 1.0 assert main(0) == 0.0 + + +def test_manual_construct_ifelse_from_blocks(): + scf.IfElse( + cond=ir.TestValue(), + then_body=ir.Block( + stmts=[ + py.Constant(5), + ], + ), + else_body=ir.Block( + stmts=[ + py.Constant(11), + ], + ), + )