44from dataclasses import field , dataclass
55
66from kirin import ir , types , lowering
7- from kirin .dialects import cf , py , func , ilist
7+ from kirin .dialects import cf , func , ilist
88
99from bloqade .qasm2 .types import CRegType , QRegType
1010from bloqade .qasm2 .dialects import uop , core , expr , glob , noise , parallel
@@ -24,7 +24,7 @@ def loads(
2424 source : str ,
2525 kernel_name : str ,
2626 * ,
27- returns : list [ str ] | None = None ,
27+ returns : str | None = None ,
2828 globals : dict [str , Any ] | None = None ,
2929 file : str | None = None ,
3030 lineno_offset : int = 0 ,
@@ -36,13 +36,6 @@ def loads(
3636 # TODO: add source info
3737 stmt = loads (source )
3838
39- returns = [] if returns is None else returns
40-
41- if len (returns ) > 1 and py .tuple .dialect not in self .dialects :
42- raise lowering .BuildError (
43- "Cannot return multiple values without tuple dialect"
44- )
45-
4639 state = lowering .State (
4740 self ,
4841 file = file ,
@@ -56,22 +49,14 @@ def loads(
5649 try :
5750 self .visit (state , stmt )
5851 # append return statement with the return values
59- values : list [ir .SSAValue ] = []
60- for name in returns :
61- value = frame .get_local (name )
62- if value is None :
63- raise lowering .BuildError (f"Undefined variable { name } " )
64- values .append (value )
65-
66- match values :
67- case []:
68- return_node = frame .push (func .ConstantNone ())
69- case [value ]:
70- return_node = value
71- case [* values ]:
72- return_node = frame .push (py .tuple .New (values = tuple (values )))
73-
74- return_node = frame .push (func .Return (value_or_stmt = return_node ))
52+ if returns is not None :
53+ return_value = frame .get (returns )
54+ if return_value is None :
55+ raise lowering .BuildError (f"Cannot find return value { returns } " )
56+ else :
57+ return_value = func .ConstantNone ()
58+
59+ return_node = frame .push (func .Return (value_or_stmt = return_value ))
7560
7661 except lowering .BuildError as e :
7762 hint = state .error_hint (
0 commit comments