@@ -15,30 +15,32 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1515 max_lines : int = field (default = 3 , kw_only = True )
1616 hint_indent : int = field (default = 2 , kw_only = True )
1717 hint_show_lineno : bool = field (default = True , kw_only = True )
18- stacktrace : bool = field (default = False , kw_only = True )
18+ stacktrace : bool = field (default = True , kw_only = True )
1919
2020 def run (
2121 self ,
2222 stmt : ast .Node ,
2323 * ,
24- state : lowering .State | None = None ,
2524 source : str | None = None ,
2625 globals : dict [str , Any ] | None = None ,
2726 file : str | None = None ,
2827 lineno_offset : int = 0 ,
2928 col_offset : int = 0 ,
3029 compactify : bool = True ,
31- ) -> ir .Statement :
30+ ) -> ir .Region :
3231 # TODO: add source info
33- state = state or lowering .State (
32+ state = lowering .State (
3433 self ,
3534 file = file ,
3635 lineno_offset = lineno_offset ,
3736 col_offset = col_offset ,
3837 )
39- with state .frame ([stmt ], globals = globals ) as frame :
38+ with state .frame (
39+ [stmt ],
40+ globals = globals ,
41+ ) as frame :
4042 try :
41- state . lower ( stmt )
43+ self . visit ( state , stmt )
4244 except lowering .BuildError as e :
4345 hint = state .error_hint (
4446 e ,
@@ -56,22 +58,27 @@ def run(
5658 raise e
5759
5860 region = frame .curr_region
59- if not region .blocks :
60- raise ValueError ("No block generated" )
61-
62- code = region .blocks [0 ].first_stmt
63- if code is None :
64- raise ValueError ("No code generated" )
6561
6662 if compactify :
6763 from kirin .rewrite import Walk , CFGCompactify
6864
69- Walk (CFGCompactify ()).rewrite (code )
70- return code
65+ Walk (CFGCompactify ()).rewrite (region )
66+ return region
67+
68+ def visit (self , state : lowering .State [ast .Node ], node : ast .Node ) -> lowering .Result :
69+ name = node .__class__ .__name__
70+ return getattr (self , f"visit_{ name } " , self .generic_visit )(state , node )
7171
72- def visit (
72+ def generic_visit (
7373 self , state : lowering .State [ast .Node ], node : ast .Node
74- ) -> lowering .Result : ...
74+ ) -> lowering .Result :
75+ if isinstance (node , ast .Node ):
76+ raise lowering .BuildError (
77+ f"Cannot lower { node .__class__ .__name__ } node: { node } "
78+ )
79+ raise lowering .BuildError (
80+ f"Unexpected `{ node .__class__ .__name__ } ` node: { repr (node )} is not an AST node"
81+ )
7582
7683 def lower_literal (self , state : lowering .State [ast .Node ], value ) -> ir .SSAValue :
7784 if isinstance (value , int ):
@@ -261,7 +268,7 @@ def visit_BinOp(self, state: lowering.State[ast.Node], node: ast.BinOp):
261268 else :
262269 stmt_type = expr .Div
263270
264- state .current_frame .push (
271+ return state .current_frame .push (
265272 stmt_type (
266273 lhs = state .lower (node .lhs ).expect_one (),
267274 rhs = state .lower (node .rhs ).expect_one (),
@@ -398,7 +405,8 @@ def visit_Number(self, state: lowering.State[ast.Node], node: ast.Number):
398405 stmt = expr .ConstInt (value = node .value )
399406 else :
400407 stmt = expr .ConstFloat (value = node .value )
401- return state .current_frame .push (stmt ).result
408+ state .current_frame .push (stmt )
409+ return stmt
402410
403411 def visit_Pi (self , state : lowering .State [ast .Node ], node : ast .Pi ):
404412 return state .current_frame .push (expr .ConstPI ()).result
0 commit comments