Skip to content

Commit 0160210

Browse files
committed
restricting to only one return value
1 parent da8ecdb commit 0160210

File tree

1 file changed

+10
-25
lines changed

1 file changed

+10
-25
lines changed

src/bloqade/qasm2/parse/lowering.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import field, dataclass
55

66
from kirin import ir, types, lowering
7-
from kirin.dialects import cf, py, func, ilist
7+
from kirin.dialects import cf, func, ilist
88

99
from bloqade.qasm2.types import CRegType, QRegType
1010
from bloqade.qasm2.dialects import uop, core, expr, glob, noise, parallel
@@ -24,7 +24,7 @@ def loads(
2424
source: str,
2525
kernel_name: str,
2626
*,
27-
returns: list[str] | None = None,
27+
returns: str | None = None,
2828
globals: dict[str, Any] | None = None,
2929
file: str | None = None,
3030
lineno_offset: int = 0,
@@ -36,13 +36,6 @@ def loads(
3636
# TODO: add source info
3737
stmt = loads(source)
3838

39-
returns = [] if returns is None else returns
40-
41-
if len(returns) > 1 and py.tuple.dialect not in self.dialects:
42-
raise lowering.BuildError(
43-
"Cannot return multiple values without tuple dialect"
44-
)
45-
4639
state = lowering.State(
4740
self,
4841
file=file,
@@ -56,22 +49,14 @@ def loads(
5649
try:
5750
self.visit(state, stmt)
5851
# append return statement with the return values
59-
values: list[ir.SSAValue] = []
60-
for name in returns:
61-
value = frame.get_local(name)
62-
if value is None:
63-
raise lowering.BuildError(f"Undefined variable {name}")
64-
values.append(value)
65-
66-
match values:
67-
case []:
68-
return_node = frame.push(func.ConstantNone())
69-
case [value]:
70-
return_node = value
71-
case [*values]:
72-
return_node = frame.push(py.tuple.New(values=tuple(values)))
73-
74-
return_node = frame.push(func.Return(value_or_stmt=return_node))
52+
if returns is not None:
53+
return_value = frame.get(returns)
54+
if return_value is None:
55+
raise lowering.BuildError(f"Cannot find return value {returns}")
56+
else:
57+
return_value = func.ConstantNone()
58+
59+
return_node = frame.push(func.Return(value_or_stmt=return_value))
7560

7661
except lowering.BuildError as e:
7762
hint = state.error_hint(

0 commit comments

Comments
 (0)