Skip to content

Commit 7f4b75a

Browse files
authored
cleanup loads/loadfile in lowering (#242)
follow up #240 I removed the `returns` option cuz it seems no one is using it. We can add something back if there is an actual user asking for it. I'm not sure why it was copy pasting `QASM2.run` method in the old `loads` implementation, maybe there was some issue with the lowering so this was a workaround? @weinbe58
1 parent 7125ea4 commit 7f4b75a

File tree

3 files changed

+62
-153
lines changed

3 files changed

+62
-153
lines changed

src/bloqade/qasm2/_qasm_loading.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import os
2+
import logging
3+
import pathlib
14
from typing import Any
25

3-
from kirin import ir
6+
from kirin import ir, types
7+
from kirin.dialects import func
48

9+
from . import parse
510
from .groups import main
611
from .parse.lowering import QASM2
712

@@ -46,19 +51,40 @@ def loads(
4651
''')
4752
```
4853
"""
49-
return QASM2(dialects or main).loads(
50-
qasm,
51-
kernel_name=kernel_name,
52-
globals=globals,
54+
# TODO: add source info
55+
stmt = parse.loads(qasm)
56+
qasm2_lowering = QASM2(dialects or main)
57+
body = qasm2_lowering.run(
58+
stmt,
59+
source=qasm,
5360
file=file,
61+
globals=globals,
5462
lineno_offset=lineno_offset,
5563
col_offset=col_offset,
5664
compactify=compactify,
5765
)
66+
return_value = func.ConstantNone()
67+
body.blocks[0].stmts.append(return_value)
68+
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
69+
70+
code = func.Function(
71+
sym_name=kernel_name,
72+
signature=func.Signature((), types.NoneType),
73+
body=body,
74+
)
75+
76+
return ir.Method(
77+
mod=None,
78+
py_func=None,
79+
sym_name=kernel_name,
80+
arg_names=[],
81+
dialects=qasm2_lowering.dialects,
82+
code=code,
83+
)
5884

5985

6086
def loadfile(
61-
qasm_file: str,
87+
qasm_file: str | pathlib.Path,
6288
*,
6389
kernel_name: str = "main",
6490
dialects: ir.DialectGroup | None = None,
@@ -83,10 +109,27 @@ def loadfile(
83109
col_offset (int): The column number offset for error reporting. Defaults to 0.
84110
compactify (bool): Whether to compactify the output. Defaults to True.
85111
"""
86-
with open(qasm_file, "r") as f:
87-
qasm = f.read()
112+
if isinstance(file, pathlib.Path):
113+
qasm_file_: pathlib.Path = qasm_file # type: ignore
114+
else:
115+
qasm_file_ = pathlib.Path(*os.path.split(qasm_file))
116+
117+
if not qasm_file_.is_file():
118+
raise FileNotFoundError(f"File {qasm_file_} does not exist")
119+
120+
if not qasm_file_.name.endswith(".qasm") or not qasm_file_.name.endswith(".qasm2"):
121+
logging.warning(
122+
f"File {qasm_file_} does not end with .qasm or .qasm2. "
123+
"This may cause issues with loading the file."
124+
)
125+
126+
kernel_name = file.name.replace(".qasm", "") if kernel_name is None else kernel_name
127+
128+
with qasm_file_.open("r") as f:
129+
source = f.read()
130+
88131
return loads(
89-
qasm,
132+
source,
90133
kernel_name=kernel_name,
91134
dialects=dialects,
92135
globals=globals,

src/bloqade/qasm2/parse/lowering.py

Lines changed: 2 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
import pathlib
31
from typing import Any
42
from dataclasses import field, dataclass
53

@@ -19,119 +17,6 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1917
hint_show_lineno: bool = field(default=True, kw_only=True)
2018
stacktrace: bool = field(default=True, kw_only=True)
2119

22-
def loads(
23-
self,
24-
source: str,
25-
kernel_name: str,
26-
*,
27-
returns: str | None = None,
28-
globals: dict[str, Any] | None = None,
29-
file: str | None = None,
30-
lineno_offset: int = 0,
31-
col_offset: int = 0,
32-
compactify: bool = True,
33-
) -> ir.Method:
34-
from ..parse import loads
35-
36-
# TODO: add source info
37-
stmt = loads(source)
38-
39-
state = lowering.State(
40-
self,
41-
file=file,
42-
lineno_offset=lineno_offset,
43-
col_offset=col_offset,
44-
)
45-
with state.frame(
46-
[stmt],
47-
globals=globals,
48-
finalize_next=False,
49-
) as frame:
50-
try:
51-
self.visit(state, stmt)
52-
# append return statement with the return values
53-
if returns is not None:
54-
return_value = frame.get(returns)
55-
if return_value is None:
56-
raise lowering.BuildError(f"Cannot find return value {returns}")
57-
else:
58-
return_value = func.ConstantNone()
59-
60-
return_node = frame.push(func.Return(value_or_stmt=return_value))
61-
62-
except lowering.BuildError as e:
63-
hint = state.error_hint(
64-
e,
65-
max_lines=self.max_lines,
66-
indent=self.hint_indent,
67-
show_lineno=self.hint_show_lineno,
68-
)
69-
if self.stacktrace:
70-
raise Exception(
71-
f"{e.args[0]}\n\n{hint}",
72-
*e.args[1:],
73-
) from e
74-
else:
75-
e.args = (hint,)
76-
raise e
77-
78-
region = frame.curr_region
79-
80-
if compactify:
81-
from kirin.rewrite import Walk, CFGCompactify
82-
83-
Walk(CFGCompactify()).rewrite(region)
84-
85-
code = func.Function(
86-
sym_name=kernel_name,
87-
signature=func.Signature((), return_node.value.type),
88-
body=region,
89-
)
90-
91-
return ir.Method(
92-
mod=None,
93-
py_func=None,
94-
sym_name=kernel_name,
95-
arg_names=[],
96-
dialects=self.dialects,
97-
code=code,
98-
)
99-
100-
def loadfile(
101-
self,
102-
file: str | pathlib.Path,
103-
*,
104-
kernel_name: str | None = None,
105-
returns: str | None = None,
106-
globals: dict[str, Any] | None = None,
107-
lineno_offset: int = 0,
108-
col_offset: int = 0,
109-
compactify: bool = True,
110-
) -> ir.Method:
111-
if isinstance(file, str):
112-
file = pathlib.Path(*os.path.split(file))
113-
114-
if not file.is_file() or not file.name.endswith(".qasm"):
115-
raise ValueError("File must be a .qasm file")
116-
117-
kernel_name = (
118-
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
119-
)
120-
121-
with file.open("r") as f:
122-
source = f.read()
123-
124-
return self.loads(
125-
source,
126-
kernel_name,
127-
returns=returns,
128-
globals=globals,
129-
file=str(file),
130-
lineno_offset=lineno_offset,
131-
col_offset=col_offset,
132-
compactify=compactify,
133-
)
134-
13520
def run(
13621
self,
13722
stmt: ast.Node,
@@ -153,25 +38,9 @@ def run(
15338
with state.frame(
15439
[stmt],
15540
globals=globals,
41+
finalize_next=False,
15642
) as frame:
157-
try:
158-
self.visit(state, stmt)
159-
except lowering.BuildError as e:
160-
hint = state.error_hint(
161-
e,
162-
max_lines=self.max_lines,
163-
indent=self.hint_indent,
164-
show_lineno=self.hint_show_lineno,
165-
)
166-
if self.stacktrace:
167-
raise Exception(
168-
f"{e.args[0]}\n\n{hint}",
169-
*e.args[1:],
170-
) from e
171-
else:
172-
e.args = (hint,)
173-
raise e
174-
43+
self.visit(state, stmt)
17544
region = frame.curr_region
17645

17746
if compactify:

test/qasm2/test_lowering.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import tempfile
33
import textwrap
44

5-
from kirin import ir
5+
from kirin import ir, types
66
from kirin.dialects import func
77

88
from bloqade import qasm2
@@ -37,12 +37,7 @@ def test_loadfile():
3737
f.write(lines)
3838

3939
file = pathlib.Path(f"{tmp_dir}/test.qasm")
40-
kernel = QASM2(qasm2.main).loadfile(file, returns="c")
41-
42-
assert isinstance(
43-
(ret := kernel.callable_region.blocks[0].last_stmt), func.Return
44-
)
45-
assert ret.value.type.is_equal(qasm2.types.CRegType)
40+
qasm2.loadfile(file)
4641

4742

4843
def test_negative_lowering():
@@ -54,7 +49,7 @@ def test_negative_lowering():
5449
rz(-0.2) q[0];
5550
"""
5651

57-
entry = QASM2(qasm2.main).loads(mwe, "entry", returns="q")
52+
entry = qasm2.loads(mwe)
5853

5954
body = ir.Region(
6055
ir.Block(
@@ -66,17 +61,19 @@ def test_negative_lowering():
6661
(idx := qasm2.expr.ConstInt(value=0)),
6762
(qubit := qasm2.core.QRegGet(qreg.result, idx.result)),
6863
(qasm2.uop.RZ(qubit.result, theta.result)),
69-
(func.Return(qreg.result)),
64+
(none := func.ConstantNone()),
65+
(func.Return(none.result)),
7066
]
7167
)
7268
)
7369

7470
code = func.Function(
75-
sym_name="entry",
76-
signature=func.Signature((), qasm2.QRegType),
71+
sym_name="main",
72+
signature=func.Signature((), types.NoneType),
7773
body=body,
7874
)
7975

8076
code.print()
77+
entry.print()
8178

8279
assert entry.code.is_structurally_equal(code)

0 commit comments

Comments
 (0)