Skip to content

Commit 5de37c7

Browse files
authored
Merge branch 'main' into phil/ccx-rewrite
2 parents 9d4fbc4 + 89ffbab commit 5de37c7

File tree

4 files changed

+69
-157
lines changed

4 files changed

+69
-157
lines changed

src/bloqade/qasm2/_qasm_loading.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
import os
2+
import logging
3+
import pathlib
14
from typing import Any
25

3-
from kirin import ir
6+
from kirin import ir, types
7+
from kirin.dialects import func
48

9+
from . import parse
510
from .groups import main
611
from .parse.lowering import QASM2
712

@@ -46,19 +51,40 @@ def loads(
4651
''')
4752
```
4853
"""
49-
return QASM2(dialects or main).loads(
50-
qasm,
51-
kernel_name=kernel_name,
52-
globals=globals,
54+
# TODO: add source info
55+
stmt = parse.loads(qasm)
56+
qasm2_lowering = QASM2(dialects or main)
57+
body = qasm2_lowering.run(
58+
stmt,
59+
source=qasm,
5360
file=file,
61+
globals=globals,
5462
lineno_offset=lineno_offset,
5563
col_offset=col_offset,
5664
compactify=compactify,
5765
)
66+
return_value = func.ConstantNone()
67+
body.blocks[0].stmts.append(return_value)
68+
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
69+
70+
code = func.Function(
71+
sym_name=kernel_name,
72+
signature=func.Signature((), types.NoneType),
73+
body=body,
74+
)
75+
76+
return ir.Method(
77+
mod=None,
78+
py_func=None,
79+
sym_name=kernel_name,
80+
arg_names=[],
81+
dialects=qasm2_lowering.dialects,
82+
code=code,
83+
)
5884

5985

6086
def loadfile(
61-
qasm_file: str,
87+
qasm_file: str | pathlib.Path,
6288
*,
6389
kernel_name: str = "main",
6490
dialects: ir.DialectGroup | None = None,
@@ -83,10 +109,27 @@ def loadfile(
83109
col_offset (int): The column number offset for error reporting. Defaults to 0.
84110
compactify (bool): Whether to compactify the output. Defaults to True.
85111
"""
86-
with open(qasm_file, "r") as f:
87-
qasm = f.read()
112+
if isinstance(file, pathlib.Path):
113+
qasm_file_: pathlib.Path = qasm_file # type: ignore
114+
else:
115+
qasm_file_ = pathlib.Path(*os.path.split(qasm_file))
116+
117+
if not qasm_file_.is_file():
118+
raise FileNotFoundError(f"File {qasm_file_} does not exist")
119+
120+
if not qasm_file_.name.endswith(".qasm") or not qasm_file_.name.endswith(".qasm2"):
121+
logging.warning(
122+
f"File {qasm_file_} does not end with .qasm or .qasm2. "
123+
"This may cause issues with loading the file."
124+
)
125+
126+
kernel_name = file.name.replace(".qasm", "") if kernel_name is None else kernel_name
127+
128+
with qasm_file_.open("r") as f:
129+
source = f.read()
130+
88131
return loads(
89-
qasm,
132+
source,
90133
kernel_name=kernel_name,
91134
dialects=dialects,
92135
globals=globals,

src/bloqade/qasm2/parse/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,22 @@ def loadfile(file: str | pathlib.Path):
1919
return loads(f.read())
2020

2121

22-
def pprint(node: ast.Node, *, console: Console | None = None):
22+
def pprint(node: ast.Node, *, console: Console | None = None, no_color: bool = False):
2323
if console:
24-
return Printer(console).visit(node)
24+
printer = Printer(console)
2525
else:
26-
Printer().visit(node)
26+
printer = Printer()
27+
printer.console.no_color = no_color
28+
printer.visit(node)
2729

2830

29-
def spprint(node: ast.Node, *, console: Console | None = None):
31+
def spprint(node: ast.Node, *, console: Console | None = None, no_color: bool = False):
3032
if console:
3133
printer = Printer(console)
3234
else:
3335
printer = Printer()
3436

37+
printer.console.no_color = no_color
3538
with printer.string_io() as stream:
3639
printer.visit(node)
3740
return stream.getvalue()

src/bloqade/qasm2/parse/lowering.py

Lines changed: 2 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import os
2-
import pathlib
31
from typing import Any
42
from dataclasses import field, dataclass
53

@@ -19,119 +17,6 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1917
hint_show_lineno: bool = field(default=True, kw_only=True)
2018
stacktrace: bool = field(default=True, kw_only=True)
2119

22-
def loads(
23-
self,
24-
source: str,
25-
kernel_name: str,
26-
*,
27-
returns: str | None = None,
28-
globals: dict[str, Any] | None = None,
29-
file: str | None = None,
30-
lineno_offset: int = 0,
31-
col_offset: int = 0,
32-
compactify: bool = True,
33-
) -> ir.Method:
34-
from ..parse import loads
35-
36-
# TODO: add source info
37-
stmt = loads(source)
38-
39-
state = lowering.State(
40-
self,
41-
file=file,
42-
lineno_offset=lineno_offset,
43-
col_offset=col_offset,
44-
)
45-
with state.frame(
46-
[stmt],
47-
globals=globals,
48-
finalize_next=False,
49-
) as frame:
50-
try:
51-
self.visit(state, stmt)
52-
# append return statement with the return values
53-
if returns is not None:
54-
return_value = frame.get(returns)
55-
if return_value is None:
56-
raise lowering.BuildError(f"Cannot find return value {returns}")
57-
else:
58-
return_value = func.ConstantNone()
59-
60-
return_node = frame.push(func.Return(value_or_stmt=return_value))
61-
62-
except lowering.BuildError as e:
63-
hint = state.error_hint(
64-
e,
65-
max_lines=self.max_lines,
66-
indent=self.hint_indent,
67-
show_lineno=self.hint_show_lineno,
68-
)
69-
if self.stacktrace:
70-
raise Exception(
71-
f"{e.args[0]}\n\n{hint}",
72-
*e.args[1:],
73-
) from e
74-
else:
75-
e.args = (hint,)
76-
raise e
77-
78-
region = frame.curr_region
79-
80-
if compactify:
81-
from kirin.rewrite import Walk, CFGCompactify
82-
83-
Walk(CFGCompactify()).rewrite(region)
84-
85-
code = func.Function(
86-
sym_name=kernel_name,
87-
signature=func.Signature((), return_node.value.type),
88-
body=region,
89-
)
90-
91-
return ir.Method(
92-
mod=None,
93-
py_func=None,
94-
sym_name=kernel_name,
95-
arg_names=[],
96-
dialects=self.dialects,
97-
code=code,
98-
)
99-
100-
def loadfile(
101-
self,
102-
file: str | pathlib.Path,
103-
*,
104-
kernel_name: str | None = None,
105-
returns: str | None = None,
106-
globals: dict[str, Any] | None = None,
107-
lineno_offset: int = 0,
108-
col_offset: int = 0,
109-
compactify: bool = True,
110-
) -> ir.Method:
111-
if isinstance(file, str):
112-
file = pathlib.Path(*os.path.split(file))
113-
114-
if not file.is_file() or not file.name.endswith(".qasm"):
115-
raise ValueError("File must be a .qasm file")
116-
117-
kernel_name = (
118-
file.name.replace(".qasm", "") if kernel_name is None else kernel_name
119-
)
120-
121-
with file.open("r") as f:
122-
source = f.read()
123-
124-
return self.loads(
125-
source,
126-
kernel_name,
127-
returns=returns,
128-
globals=globals,
129-
file=str(file),
130-
lineno_offset=lineno_offset,
131-
col_offset=col_offset,
132-
compactify=compactify,
133-
)
134-
13520
def run(
13621
self,
13722
stmt: ast.Node,
@@ -153,25 +38,9 @@ def run(
15338
with state.frame(
15439
[stmt],
15540
globals=globals,
41+
finalize_next=False,
15642
) as frame:
157-
try:
158-
self.visit(state, stmt)
159-
except lowering.BuildError as e:
160-
hint = state.error_hint(
161-
e,
162-
max_lines=self.max_lines,
163-
indent=self.hint_indent,
164-
show_lineno=self.hint_show_lineno,
165-
)
166-
if self.stacktrace:
167-
raise Exception(
168-
f"{e.args[0]}\n\n{hint}",
169-
*e.args[1:],
170-
) from e
171-
else:
172-
e.args = (hint,)
173-
raise e
174-
43+
self.visit(state, stmt)
17544
region = frame.curr_region
17645

17746
if compactify:

test/qasm2/test_lowering.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import tempfile
33
import textwrap
44

5-
from kirin import ir
5+
from kirin import ir, types
66
from kirin.dialects import func
77

88
from bloqade import qasm2
@@ -37,12 +37,7 @@ def test_loadfile():
3737
f.write(lines)
3838

3939
file = pathlib.Path(f"{tmp_dir}/test.qasm")
40-
kernel = QASM2(qasm2.main).loadfile(file, returns="c")
41-
42-
assert isinstance(
43-
(ret := kernel.callable_region.blocks[0].last_stmt), func.Return
44-
)
45-
assert ret.value.type.is_equal(qasm2.types.CRegType)
40+
qasm2.loadfile(file)
4641

4742

4843
def test_negative_lowering():
@@ -54,7 +49,7 @@ def test_negative_lowering():
5449
rz(-0.2) q[0];
5550
"""
5651

57-
entry = QASM2(qasm2.main).loads(mwe, "entry", returns="q")
52+
entry = qasm2.loads(mwe)
5853

5954
body = ir.Region(
6055
ir.Block(
@@ -66,17 +61,19 @@ def test_negative_lowering():
6661
(idx := qasm2.expr.ConstInt(value=0)),
6762
(qubit := qasm2.core.QRegGet(qreg.result, idx.result)),
6863
(qasm2.uop.RZ(qubit.result, theta.result)),
69-
(func.Return(qreg.result)),
64+
(none := func.ConstantNone()),
65+
(func.Return(none.result)),
7066
]
7167
)
7268
)
7369

7470
code = func.Function(
75-
sym_name="entry",
76-
signature=func.Signature((), qasm2.QRegType),
71+
sym_name="main",
72+
signature=func.Signature((), types.NoneType),
7773
body=body,
7874
)
7975

8076
code.print()
77+
entry.print()
8178

8279
assert entry.code.is_structurally_equal(code)

0 commit comments

Comments
 (0)