Skip to content

Commit 4788445

Browse files
authored
Merge branch 'main' into david/pyright-athon
2 parents 07c6380 + 4ebc574 commit 4788445

File tree

3 files changed

+176
-4
lines changed

3 files changed

+176
-4
lines changed

src/bloqade/qasm2/parse/lowering.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import pathlib
13
from typing import Any
24
from dataclasses import field, dataclass
35

@@ -17,6 +19,118 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1719
hint_show_lineno: bool = field(default=True, kw_only=True)
1820
stacktrace: bool = field(default=True, kw_only=True)
1921

22+
def loads(
23+
self,
24+
source: str,
25+
kernel_name: str,
26+
*,
27+
returns: str | None = None,
28+
globals: dict[str, Any] | None = None,
29+
file: str | None = None,
30+
lineno_offset: int = 0,
31+
col_offset: int = 0,
32+
compactify: bool = True,
33+
) -> ir.Method:
34+
from ..parse import loads
35+
36+
# TODO: add source info
37+
stmt = loads(source)
38+
39+
state = lowering.State(
40+
self,
41+
file=file,
42+
lineno_offset=lineno_offset,
43+
col_offset=col_offset,
44+
)
45+
with state.frame(
46+
[stmt],
47+
globals=globals,
48+
) as frame:
49+
try:
50+
self.visit(state, stmt)
51+
# append return statement with the return values
52+
if returns is not None:
53+
return_value = frame.get(returns)
54+
if return_value is None:
55+
raise lowering.BuildError(f"Cannot find return value {returns}")
56+
else:
57+
return_value = func.ConstantNone()
58+
59+
return_node = frame.push(func.Return(value_or_stmt=return_value))
60+
61+
except lowering.BuildError as e:
62+
hint = state.error_hint(
63+
e,
64+
max_lines=self.max_lines,
65+
indent=self.hint_indent,
66+
show_lineno=self.hint_show_lineno,
67+
)
68+
if self.stacktrace:
69+
raise Exception(
70+
f"{e.args[0]}\n\n{hint}",
71+
*e.args[1:],
72+
) from e
73+
else:
74+
e.args = (hint,)
75+
raise e
76+
77+
region = frame.curr_region
78+
79+
if compactify:
80+
from kirin.rewrite import Walk, CFGCompactify
81+
82+
Walk(CFGCompactify()).rewrite(region)
83+
84+
code = func.Function(
85+
sym_name=kernel_name,
86+
signature=func.Signature((), return_node.value.type),
87+
body=region,
88+
)
89+
90+
return ir.Method(
91+
mod=None,
92+
py_func=None,
93+
sym_name=kernel_name,
94+
arg_names=[],
95+
dialects=self.dialects,
96+
code=code,
97+
)
98+
99+
def loadfile(
100+
self,
101+
file: str | pathlib.Path,
102+
*,
103+
kernel_name: str | None = None,
104+
returns: str | None = None,
105+
globals: dict[str, Any] | None = None,
106+
lineno_offset: int = 0,
107+
col_offset: int = 0,
108+
compactify: bool = True,
109+
) -> ir.Method:
110+
if isinstance(file, str):
111+
file = pathlib.Path(*os.path.split(file))
112+
113+
if not file.is_file() or not file.name.endswith(".qasm"):
114+
raise ValueError("File must be a .qasm file")
115+
116+
kernel_name = (
117+
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
118+
)
119+
120+
with file.open("r") as f:
121+
source = f.read()
122+
123+
return self.loads(
124+
source,
125+
kernel_name,
126+
returns=returns,
127+
globals=globals,
128+
file=str(file),
129+
lineno_offset=lineno_offset,
130+
col_offset=col_offset,
131+
compactify=compactify,
132+
)
133+
20134
def run(
21135
self,
22136
stmt: ast.Node,

test/qasm2/test_lowering.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import pathlib
2+
import tempfile
13
import textwrap
24

5+
from kirin.dialects import func
6+
37
from bloqade import qasm2
48
from bloqade.qasm2.parse.lowering import QASM2
59

@@ -17,6 +21,24 @@
1721
rx(pi/2) q[0];
1822
"""
1923
)
20-
ast = qasm2.parse.loads(lines)
21-
code = QASM2(qasm2.main).run(ast)
22-
code.print()
24+
25+
26+
def test_run_lowering():
27+
ast = qasm2.parse.loads(lines)
28+
code = QASM2(qasm2.main).run(ast)
29+
code.print()
30+
31+
32+
def test_loadfile():
33+
34+
with tempfile.TemporaryDirectory() as tmp_dir:
35+
with open(f"{tmp_dir}/test.qasm", "w") as f:
36+
f.write(lines)
37+
38+
file = pathlib.Path(f"{tmp_dir}/test.qasm")
39+
kernel = QASM2(qasm2.main).loadfile(file, returns="c")
40+
41+
assert isinstance(
42+
(ret := kernel.callable_region.blocks[0].last_stmt), func.Return
43+
)
44+
assert ret.value.type.is_equal(qasm2.types.CRegType)

test/qasm2/test_native.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import textwrap
33

44
import cirq
5-
import cirq.contrib
65
import cirq.testing
76
import cirq.contrib.qasm_import as qasm_import
87
import cirq.circuits.qasm_output as qasm_output
@@ -86,6 +85,18 @@ def generator(n_tests: int):
8685
"""
8786
)
8887

88+
yield textwrap.dedent(
89+
"""
90+
OPENQASM 2.0;
91+
include "qelib1.inc";
92+
93+
qreg q[2];
94+
95+
cu3(0.0, 0.6, 3.141591) q[0],q[1];
96+
97+
"""
98+
)
99+
89100
rgen = np.random.RandomState(128)
90101
for num in range(n_tests):
91102
# Generate a new instance:
@@ -117,3 +128,28 @@ def kernel():
117128
cirq.testing.assert_allclose_up_to_global_phase(
118129
cirq.unitary(old_circuit), cirq.unitary(cirq_circuit), atol=1e-8
119130
)
131+
132+
133+
def test_cu3_rewrite():
134+
prog = textwrap.dedent(
135+
"""
136+
OPENQASM 2.0;
137+
include "qelib1.inc";
138+
139+
qreg q[2];
140+
141+
cu3(0.0, 0.6, 3.141591) q[0],q[1];
142+
143+
"""
144+
)
145+
146+
@qasm2.main.add(qasm2.dialects.inline)
147+
def kernel():
148+
qasm2.inline(prog)
149+
150+
walk.Walk(RydbergGateSetRewriteRule(kernel.dialects)).rewrite(kernel.code)
151+
152+
new_qasm2 = qasm2.emit.QASM2().emit_str(kernel)
153+
154+
# simple-stupid test to see if the rewrite injected a bunch of new lines
155+
assert new_qasm2.count("\n") > prog.count("\n")

0 commit comments

Comments
 (0)