|
| 1 | +from kirin import ir, types, lowering |
| 2 | +from kirin.decl import info, statement |
| 3 | +from kirin.passes import Fold, TypeInfer |
| 4 | +from kirin.prelude import structural_no_opt |
| 5 | +from kirin.analysis import const |
| 6 | +from kirin.dialects import py, func, ilist |
| 7 | + |
| 8 | +dialect = ir.Dialect("analysis") |
| 9 | + |
| 10 | + |
| 11 | +@statement(dialect=dialect) |
| 12 | +class MultiReturnStatement(ir.Statement): |
| 13 | + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) |
| 14 | + inputs: tuple[ir.SSAValue] = info.argument(types.Any) |
| 15 | + |
| 16 | + def __init__(self, *args: ir.SSAValue): |
| 17 | + super().__init__( |
| 18 | + args=args, |
| 19 | + result_types=tuple(arg.type for arg in args), |
| 20 | + args_slice={"inputs": slice(None)}, |
| 21 | + ) |
| 22 | + |
| 23 | + |
| 24 | +@ir.dialect_group(structural_no_opt.add(dialect)) |
| 25 | +def dialect_group_test(self): |
| 26 | + fold = Fold(self) |
| 27 | + type_infer = TypeInfer(self) |
| 28 | + |
| 29 | + def run_pass(mt): |
| 30 | + type_infer(mt) |
| 31 | + fold(mt) |
| 32 | + |
| 33 | + return run_pass |
| 34 | + |
| 35 | + |
| 36 | +def test_multi_return_default_prop(): |
| 37 | + |
| 38 | + stmts = [ |
| 39 | + (a := py.Constant(1)), |
| 40 | + (b := py.Constant(2)), |
| 41 | + (res := MultiReturnStatement(a.result, b.result)), |
| 42 | + (return_result := ilist.New((res.results[0], res.results[1]))), |
| 43 | + (func.Return(return_result.result)), |
| 44 | + ] |
| 45 | + |
| 46 | + body = ir.Region(ir.Block(stmts)) |
| 47 | + func_code = func.Function( |
| 48 | + sym_name="test", signature=func.Signature((), types.Any), body=body |
| 49 | + ) |
| 50 | + mt = ir.Method(None, None, "test", (), dialect_group_test, func_code) |
| 51 | + |
| 52 | + frame, return_result = const.Propagate(dialect_group_test).run_analysis( |
| 53 | + mt, no_raise=False |
| 54 | + ) |
| 55 | + |
| 56 | + assert frame.entries[res.results[0]] == const.Unknown() |
| 57 | + assert frame.entries[res.results[1]] == const.Unknown() |
0 commit comments