Skip to content

Commit 20fd0ad

Browse files
authored
Add back returns kwarg to qasm2.loads / qasm2.loadfile (#266)
1 parent df16cbf commit 20fd0ad

File tree

3 files changed

+88
-12
lines changed

3 files changed

+88
-12
lines changed

src/bloqade/qasm2/_qasm_loading.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pathlib
44
from typing import Any
55

6-
from kirin import ir, types
6+
from kirin import ir, lowering
77
from kirin.dialects import func
88

99
from . import parse
@@ -16,6 +16,7 @@ def loads(
1616
*,
1717
kernel_name: str = "main",
1818
dialects: ir.DialectGroup | None = None,
19+
returns: str | None = None,
1920
globals: dict[str, Any] | None = None,
2021
file: str | None = None,
2122
lineno_offset: int = 0,
@@ -54,7 +55,7 @@ def loads(
5455
# TODO: add source info
5556
stmt = parse.loads(qasm)
5657
qasm2_lowering = QASM2(dialects or main)
57-
body = qasm2_lowering.run(
58+
frame = qasm2_lowering.get_frame(
5859
stmt,
5960
source=qasm,
6061
file=file,
@@ -63,13 +64,21 @@ def loads(
6364
col_offset=col_offset,
6465
compactify=compactify,
6566
)
66-
return_value = func.ConstantNone()
67-
body.blocks[0].stmts.append(return_value)
68-
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
6967

68+
if returns is not None:
69+
return_value = frame.get(returns)
70+
if return_value is None:
71+
raise lowering.BuildError(f"Cannot find return value {returns}")
72+
else:
73+
return_value = func.ConstantNone()
74+
frame.push(return_value)
75+
76+
return_node = frame.push(func.Return(value_or_stmt=return_value))
77+
78+
body = frame.curr_region
7079
code = func.Function(
7180
sym_name=kernel_name,
72-
signature=func.Signature((), types.NoneType),
81+
signature=func.Signature((), return_node.value.type),
7382
body=body,
7483
)
7584

@@ -88,6 +97,7 @@ def loadfile(
8897
*,
8998
kernel_name: str = "main",
9099
dialects: ir.DialectGroup | None = None,
100+
returns: str | None = None,
91101
globals: dict[str, Any] | None = None,
92102
file: str | None = None,
93103
lineno_offset: int = 0,
@@ -132,6 +142,7 @@ def loadfile(
132142
source,
133143
kernel_name=kernel_name,
134144
dialects=dialects,
145+
returns=returns,
135146
globals=globals,
136147
file=file,
137148
lineno_offset=lineno_offset,

src/bloqade/qasm2/parse/lowering.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ def run(
2828
col_offset: int = 0,
2929
compactify: bool = True,
3030
) -> ir.Region:
31+
32+
frame = self.get_frame(
33+
stmt,
34+
source=source,
35+
globals=globals,
36+
file=file,
37+
lineno_offset=lineno_offset,
38+
col_offset=col_offset,
39+
)
40+
41+
return frame.curr_region
42+
43+
def get_frame(
44+
self,
45+
stmt: ast.Node,
46+
source: str | None = None,
47+
globals: dict[str, Any] | None = None,
48+
file: str | None = None,
49+
lineno_offset: int = 0,
50+
col_offset: int = 0,
51+
compactify: bool = True,
52+
) -> lowering.Frame:
3153
# TODO: add source info
3254
state = lowering.State(
3355
self,
@@ -41,13 +63,13 @@ def run(
4163
finalize_next=False,
4264
) as frame:
4365
self.visit(state, stmt)
44-
region = frame.curr_region
4566

46-
if compactify:
47-
from kirin.rewrite import Walk, CFGCompactify
67+
if compactify:
68+
from kirin.rewrite import Walk, CFGCompactify
69+
70+
Walk(CFGCompactify()).rewrite(frame.curr_region)
4871

49-
Walk(CFGCompactify()).rewrite(region)
50-
return region
72+
return frame
5173

5274
def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result:
5375
name = node.__class__.__name__

test/pyqrack/test_target.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import math
2+
import textwrap
23

34
import pytest
45
from kirin import ir
56
from kirin.dialects import ilist
67

78
from bloqade import qasm2
8-
from bloqade.pyqrack import PyQrack, PyQrackQubit, StackMemorySimulator, reg
9+
from bloqade.pyqrack import PyQrack, CRegister, PyQrackQubit, StackMemorySimulator, reg
910

1011

1112
def test_target():
@@ -171,3 +172,45 @@ def parallel():
171172
result = target.run(parallel)
172173

173174
assert result == [reg.Measurement.One] * 4
175+
176+
177+
def test_loads_without_return():
178+
qasm2_str = textwrap.dedent(
179+
"""
180+
OPENQASM 2.0;
181+
182+
qreg q[1];
183+
x q[0];
184+
"""
185+
)
186+
187+
main = qasm2.loads(qasm2_str)
188+
189+
sim = StackMemorySimulator(min_qubits=1)
190+
191+
result = sim.run(main)
192+
assert result is None
193+
194+
ket = sim.state_vector(main)
195+
assert ket[0] == 0
196+
197+
198+
def test_loads_with_return():
199+
qasm2_str = textwrap.dedent(
200+
"""
201+
OPENQASM 2.0;
202+
203+
qreg q[1];
204+
creg c[1];
205+
x q[0];
206+
measure q -> c;
207+
"""
208+
)
209+
210+
main = qasm2.loads(qasm2_str, returns="c")
211+
212+
sim = StackMemorySimulator(min_qubits=1)
213+
result = sim.run(main)
214+
215+
assert isinstance(result, CRegister)
216+
assert result[0] == 1

0 commit comments

Comments
 (0)