Skip to content

Commit 076d73b

Browse files
Roger-luoweinbe58
authored andcommitted
fix more stuff
1 parent 9769869 commit 076d73b

File tree

6 files changed

+71
-74
lines changed

6 files changed

+71
-74
lines changed

src/bloqade/qasm2/dialects/inline.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,23 @@ def lower(
3636
"InlineQASM takes a string literal or global string"
3737
)
3838

39+
from kirin.dialects import ilist
40+
41+
from bloqade.qasm2.groups import main
42+
from bloqade.qasm2.dialects import glob, noise, parallel
43+
3944
raw = textwrap.dedent(value)
40-
qasm_lowering = QASM2(state.parent.dialects)
41-
code = qasm_lowering.run(loads(raw))
42-
code.print()
45+
qasm_lowering = QASM2(main.union([ilist, glob, noise, parallel]))
46+
region = qasm_lowering.run(loads(raw))
47+
for qasm_stmt in region.blocks[0].stmts:
48+
qasm_stmt.detach()
49+
state.current_frame.push(qasm_stmt)
50+
51+
for block in region.blocks:
52+
for qasm_stmt in block.stmts:
53+
qasm_stmt.detach()
54+
state.current_frame.push(qasm_stmt)
55+
state.current_frame.jump_next_block()
4356

4457

4558
# NOTE: this is a dummy statement that won't appear in IR.

src/bloqade/qasm2/parse/lowering.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,32 @@ class QASM2(lowering.LoweringABC[ast.Node]):
1515
max_lines: int = field(default=3, kw_only=True)
1616
hint_indent: int = field(default=2, kw_only=True)
1717
hint_show_lineno: bool = field(default=True, kw_only=True)
18-
stacktrace: bool = field(default=False, kw_only=True)
18+
stacktrace: bool = field(default=True, kw_only=True)
1919

2020
def run(
2121
self,
2222
stmt: ast.Node,
2323
*,
24-
state: lowering.State | None = None,
2524
source: str | None = None,
2625
globals: dict[str, Any] | None = None,
2726
file: str | None = None,
2827
lineno_offset: int = 0,
2928
col_offset: int = 0,
3029
compactify: bool = True,
31-
) -> ir.Statement:
30+
) -> ir.Region:
3231
# TODO: add source info
33-
state = state or lowering.State(
32+
state = lowering.State(
3433
self,
3534
file=file,
3635
lineno_offset=lineno_offset,
3736
col_offset=col_offset,
3837
)
39-
with state.frame([stmt], globals=globals) as frame:
38+
with state.frame(
39+
[stmt],
40+
globals=globals,
41+
) as frame:
4042
try:
41-
state.lower(stmt)
43+
self.visit(state, stmt)
4244
except lowering.BuildError as e:
4345
hint = state.error_hint(
4446
e,
@@ -56,22 +58,27 @@ def run(
5658
raise e
5759

5860
region = frame.curr_region
59-
if not region.blocks:
60-
raise ValueError("No block generated")
61-
62-
code = region.blocks[0].first_stmt
63-
if code is None:
64-
raise ValueError("No code generated")
6561

6662
if compactify:
6763
from kirin.rewrite import Walk, CFGCompactify
6864

69-
Walk(CFGCompactify()).rewrite(code)
70-
return code
65+
Walk(CFGCompactify()).rewrite(region)
66+
return region
67+
68+
def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result:
69+
name = node.__class__.__name__
70+
return getattr(self, f"visit_{name}", self.generic_visit)(state, node)
7171

72-
def visit(
72+
def generic_visit(
7373
self, state: lowering.State[ast.Node], node: ast.Node
74-
) -> lowering.Result: ...
74+
) -> lowering.Result:
75+
if isinstance(node, ast.Node):
76+
raise lowering.BuildError(
77+
f"Cannot lower {node.__class__.__name__} node: {node}"
78+
)
79+
raise lowering.BuildError(
80+
f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node"
81+
)
7582

7683
def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue:
7784
if isinstance(value, int):
@@ -261,7 +268,7 @@ def visit_BinOp(self, state: lowering.State[ast.Node], node: ast.BinOp):
261268
else:
262269
stmt_type = expr.Div
263270

264-
state.current_frame.push(
271+
return state.current_frame.push(
265272
stmt_type(
266273
lhs=state.lower(node.lhs).expect_one(),
267274
rhs=state.lower(node.rhs).expect_one(),
@@ -398,7 +405,8 @@ def visit_Number(self, state: lowering.State[ast.Node], node: ast.Number):
398405
stmt = expr.ConstInt(value=node.value)
399406
else:
400407
stmt = expr.ConstFloat(value=node.value)
401-
return state.current_frame.push(stmt).result
408+
state.current_frame.push(stmt)
409+
return stmt
402410

403411
def visit_Pi(self, state: lowering.State[ast.Node], node: ast.Pi):
404412
return state.current_frame.push(expr.ConstPI()).result

src/bloqade/qasm2/passes/qasm2py.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ class _QASM2Py(RewriteRule):
1717

1818
UNARY_OPS = {
1919
expr.Neg: py.USub,
20-
expr.Sin: math.sin,
21-
expr.Cos: math.cos,
22-
expr.Tan: math.tan,
23-
expr.Exp: math.exp,
24-
expr.Sqrt: math.sqrt,
20+
expr.Sin: math.stmts.sin,
21+
expr.Cos: math.stmts.cos,
22+
expr.Tan: math.stmts.tan,
23+
expr.Exp: math.stmts.exp,
24+
expr.Sqrt: math.stmts.sqrt,
2525
}
2626

2727
BINARY_OPS = {

test/qasm2/passes/test_heuristic_noise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313

1414
class NoiseTestModel(native.MoveNoiseModelABC):
15+
16+
@classmethod
1517
def parallel_cz_errors(cls, ctrls, qargs, rest):
1618
return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest}
1719

test/qasm2/test_inline.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -64,51 +64,3 @@ def qasm2_inline_code():
6464
qasm2.inline(lines)
6565

6666
qasm2_inline_code.print()
67-
68-
69-
if __name__ == "__main__":
70-
# test_inline()
71-
72-
lines = textwrap.dedent(
73-
"""
74-
KIRIN {qasm2.glob, qasm2.uop};
75-
include "qelib1.inc";
76-
77-
qreg q1[2];
78-
qreg q2[3];
79-
80-
glob.U(1.0, 2.0, 3.0) {q1, q2}
81-
"""
82-
)
83-
84-
print(lines)
85-
86-
@qasm2.extended.add(inline)
87-
def qasm2_inline_code():
88-
qasm2.inline(lines)
89-
90-
qasm2_inline_code.print()
91-
92-
93-
lines = textwrap.dedent(
94-
"""
95-
OPENQASM 2.0;
96-
97-
qreg q[2];
98-
creg c[2];
99-
100-
h q[0];
101-
CX q[0], q[1];
102-
barrier q[0], q[1];
103-
CX q[0], q[1];
104-
rx(pi/2) q[0];
105-
"""
106-
)
107-
108-
109-
@qasm2.main.add(inline)
110-
def qasm2_inline_code():
111-
qasm2.inline(lines)
112-
113-
114-
qasm2_inline_code.print()

test/qasm2/test_lowering.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import textwrap
2+
3+
from bloqade import qasm2
4+
from bloqade.qasm2.parse.lowering import QASM2
5+
6+
lines = textwrap.dedent(
7+
"""
8+
OPENQASM 2.0;
9+
10+
qreg q[2];
11+
creg c[2];
12+
13+
h q[0];
14+
CX q[0], q[1];
15+
barrier q[0], q[1];
16+
CX q[0], q[1];
17+
rx(pi/2) q[0];
18+
"""
19+
)
20+
ast = qasm2.parse.loads(lines)
21+
code = QASM2(qasm2.main).run(ast)
22+
code.print()

0 commit comments

Comments
 (0)