Skip to content

Commit 6471e8f

Browse files
authored
load + loadfile methods for QASM2 lowering. (#159)
1 parent 6020870 commit 6471e8f

File tree

2 files changed

+142
-3
lines changed

2 files changed

+142
-3
lines changed

src/bloqade/qasm2/parse/lowering.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import pathlib
13
from typing import Any
24
from dataclasses import field, dataclass
35

@@ -17,6 +19,118 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1719
hint_show_lineno: bool = field(default=True, kw_only=True)
1820
stacktrace: bool = field(default=True, kw_only=True)
1921

22+
def loads(
23+
self,
24+
source: str,
25+
kernel_name: str,
26+
*,
27+
returns: str | None = None,
28+
globals: dict[str, Any] | None = None,
29+
file: str | None = None,
30+
lineno_offset: int = 0,
31+
col_offset: int = 0,
32+
compactify: bool = True,
33+
) -> ir.Method:
34+
from ..parse import loads
35+
36+
# TODO: add source info
37+
stmt = loads(source)
38+
39+
state = lowering.State(
40+
self,
41+
file=file,
42+
lineno_offset=lineno_offset,
43+
col_offset=col_offset,
44+
)
45+
with state.frame(
46+
[stmt],
47+
globals=globals,
48+
) as frame:
49+
try:
50+
self.visit(state, stmt)
51+
# append return statement with the return values
52+
if returns is not None:
53+
return_value = frame.get(returns)
54+
if return_value is None:
55+
raise lowering.BuildError(f"Cannot find return value {returns}")
56+
else:
57+
return_value = func.ConstantNone()
58+
59+
return_node = frame.push(func.Return(value_or_stmt=return_value))
60+
61+
except lowering.BuildError as e:
62+
hint = state.error_hint(
63+
e,
64+
max_lines=self.max_lines,
65+
indent=self.hint_indent,
66+
show_lineno=self.hint_show_lineno,
67+
)
68+
if self.stacktrace:
69+
raise Exception(
70+
f"{e.args[0]}\n\n{hint}",
71+
*e.args[1:],
72+
) from e
73+
else:
74+
e.args = (hint,)
75+
raise e
76+
77+
region = frame.curr_region
78+
79+
if compactify:
80+
from kirin.rewrite import Walk, CFGCompactify
81+
82+
Walk(CFGCompactify()).rewrite(region)
83+
84+
code = func.Function(
85+
sym_name=kernel_name,
86+
signature=func.Signature((), return_node.value.type),
87+
body=region,
88+
)
89+
90+
return ir.Method(
91+
mod=None,
92+
py_func=None,
93+
sym_name=kernel_name,
94+
arg_names=[],
95+
dialects=self.dialects,
96+
code=code,
97+
)
98+
99+
def loadfile(
100+
self,
101+
file: str | pathlib.Path,
102+
*,
103+
kernel_name: str | None = None,
104+
returns: str | None = None,
105+
globals: dict[str, Any] | None = None,
106+
lineno_offset: int = 0,
107+
col_offset: int = 0,
108+
compactify: bool = True,
109+
) -> ir.Method:
110+
if isinstance(file, str):
111+
file = pathlib.Path(*os.path.split(file))
112+
113+
if not file.is_file() or not file.name.endswith(".qasm"):
114+
raise ValueError("File must be a .qasm file")
115+
116+
kernel_name = (
117+
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
118+
)
119+
120+
with file.open("r") as f:
121+
source = f.read()
122+
123+
return self.loads(
124+
source,
125+
kernel_name,
126+
returns=returns,
127+
globals=globals,
128+
file=str(file),
129+
lineno_offset=lineno_offset,
130+
col_offset=col_offset,
131+
compactify=compactify,
132+
)
133+
20134
def run(
21135
self,
22136
stmt: ast.Node,
@@ -85,6 +199,9 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue:
85199
stmt = expr.ConstInt(value=value)
86200
elif isinstance(value, float):
87201
stmt = expr.ConstFloat(value=value)
202+
else:
203+
raise lowering.BuildError(f"Unsupported literal type {type(value)}")
204+
88205
state.current_frame.push(stmt)
89206
return stmt.result
90207

test/qasm2/test_lowering.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import pathlib
2+
import tempfile
13
import textwrap
24

5+
from kirin.dialects import func
6+
37
from bloqade import qasm2
48
from bloqade.qasm2.parse.lowering import QASM2
59

@@ -17,6 +21,24 @@
1721
rx(pi/2) q[0];
1822
"""
1923
)
20-
ast = qasm2.parse.loads(lines)
21-
code = QASM2(qasm2.main).run(ast)
22-
code.print()
24+
25+
26+
def test_run_lowering():
27+
ast = qasm2.parse.loads(lines)
28+
code = QASM2(qasm2.main).run(ast)
29+
code.print()
30+
31+
32+
def test_loadfile():
33+
34+
with tempfile.TemporaryDirectory() as tmp_dir:
35+
with open(f"{tmp_dir}/test.qasm", "w") as f:
36+
f.write(lines)
37+
38+
file = pathlib.Path(f"{tmp_dir}/test.qasm")
39+
kernel = QASM2(qasm2.main).loadfile(file, returns="c")
40+
41+
assert isinstance(
42+
(ret := kernel.callable_region.blocks[0].last_stmt), func.Return
43+
)
44+
assert ret.value.type.is_equal(qasm2.types.CRegType)

0 commit comments

Comments
 (0)