Skip to content

Commit 99e66af

Browse files
committed
get rid of redundant rewrite of QASM2 const/expr to py const/expr - already exists
1 parent 59dfcb0 commit 99e66af

File tree

6 files changed

+40
-189
lines changed

6 files changed

+40
-189
lines changed

src/bloqade/squin/passes/qasm2_gate_func_to_squin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from kirin.rewrite.abc import RewriteRule, RewriteResult
55

66
from bloqade.rewrite.passes import CallGraphPass
7+
from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule
78

89
from ..rewrite import qasm2 as qasm2_rule
910

@@ -37,7 +38,7 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
3738

3839
combined_qasm2_rules = Walk(
3940
Chain(
40-
qasm2_rule.QASM2ExprToSquin(),
41+
QASM2ToPyRule(),
4142
qasm2_rule.QASM2CoreToSquin(),
4243
qasm2_rule.QASM2UOPToSquin(),
4344
qasm2_rule.QASM2NoiseToSquin(),

src/bloqade/squin/passes/qasm2_to_squin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from bloqade.squin.rewrite.qasm2 import (
1111
QASM2UOPToSquin,
1212
QASM2CoreToSquin,
13-
QASM2ExprToSquin,
1413
QASM2NoiseToSquin,
1514
QASM2GlobParallelToSquin,
1615
)
1716

17+
# There's a QASM2Py pass that only applies an _QASM2Py rewrite rule,
18+
# I just want the rule here.
19+
from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule
20+
1821
from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass
1922

2023

@@ -26,7 +29,7 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
2629
# rewrite all QASM2 to squin first
2730
rewrite_result = Walk(
2831
Chain(
29-
QASM2ExprToSquin(),
32+
QASM2ToPyRule(),
3033
QASM2CoreToSquin(),
3134
QASM2UOPToSquin(),
3235
QASM2GlobParallelToSquin(),
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin
22
from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin
3-
from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin
43
from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin
54
from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin

src/bloqade/squin/rewrite/qasm2/expr_to_squin.py

Lines changed: 0 additions & 113 deletions
This file was deleted.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import io
2+
3+
from kirin import ir
4+
5+
from bloqade import stim, squin
6+
from bloqade.stim.emit import EmitStimMain
7+
from bloqade.stim.passes import SquinToStimPass
8+
9+
10+
def codegen(mt: ir.Method):
11+
# method should not have any arguments!
12+
buf = io.StringIO()
13+
emit = EmitStimMain(dialects=stim.main, io=buf)
14+
emit.initialize()
15+
emit.run(mt)
16+
return buf.getvalue().strip()
17+
18+
19+
@squin.kernel
20+
def test_simple_linear():
21+
22+
qs = squin.qalloc(4)
23+
m0 = squin.broadcast.measure(qs)
24+
squin.set_detector([m0[0], m0[1]], coordinates=[0, 0])
25+
m1 = squin.broadcast.measure(qs)
26+
squin.set_detector([m1[0], m1[1]], coordinates=[1, 1])
27+
28+
29+
test_simple_linear.print()
30+
SquinToStimPass(dialects=test_simple_linear.dialects)(test_simple_linear)
31+
test_simple_linear.print()
32+
print(codegen(test_simple_linear))

test/squin/passes/test_qasm2_to_squin.py

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import math
22

33
from kirin import types as kirin_types
4-
from kirin.rewrite import Walk
5-
from kirin.dialects import py, func, math as kirin_math
4+
from kirin.dialects import py
65
from kirin.dialects.ilist import IListType
76

87
from bloqade import qasm2, squin
@@ -11,79 +10,9 @@
1110
from bloqade.squin.passes import QASM2ToSquin
1211
from bloqade.rewrite.passes import AggressiveUnroll
1312
from bloqade.analysis.address import AddressAnalysis
14-
from bloqade.squin.rewrite.qasm2 import (
15-
QASM2ExprToSquin,
16-
)
1713
from bloqade.analysis.address.lattice import AddressReg
1814

1915

20-
def test_expr_rewrite():
21-
22-
@qasm2.main
23-
def expr_program():
24-
# constants
25-
x = 0 # noqa: F841
26-
# ConstPI only added from lowering
27-
y = qasm2.dialects.expr.stmts.ConstPI() # noqa: F841
28-
z = -1.75 # noqa: F841
29-
# binary ops
30-
a = 1 + 1 # noqa: F841
31-
b = 2 * 2 # noqa: F841
32-
c = 3 - 3 # noqa: F841
33-
d = 4 / 4 # noqa: F841
34-
e = 5**2 # noqa: F841
35-
36-
# math
37-
a = 0.2
38-
qasm2.sin(a)
39-
qasm2.cos(a)
40-
qasm2.tan(a)
41-
qasm2.exp(a)
42-
qasm2.ln(a)
43-
qasm2.sqrt(a)
44-
return
45-
46-
expr_program.print()
47-
48-
Walk(QASM2ExprToSquin()).rewrite(expr_program.code)
49-
50-
expr_program.print()
51-
52-
actual_stmt_sequence = list(expr_program.callable_region.walk())
53-
54-
def is_pi_const(stmt: py.Constant):
55-
return isinstance(stmt, py.Constant) and math.isclose(stmt.value.data, math.pi)
56-
57-
assert any(is_pi_const(stmt) for stmt in actual_stmt_sequence)
58-
59-
assert qasm2.expr.ConstFloat not in actual_stmt_sequence
60-
assert qasm2.expr.ConstInt not in actual_stmt_sequence
61-
assert qasm2.expr.ConstPI not in actual_stmt_sequence
62-
63-
no_const_actual_sequence = [
64-
type(stmt)
65-
for stmt in actual_stmt_sequence
66-
if not isinstance(stmt, (py.Constant, func.ConstantNone, func.Return))
67-
]
68-
69-
expected_stmt_sequence = [
70-
py.unary.stmts.USub,
71-
py.binop.Add,
72-
py.binop.Mult,
73-
py.binop.Sub,
74-
py.binop.Div,
75-
py.binop.Pow,
76-
kirin_math.stmts.sin,
77-
kirin_math.stmts.cos,
78-
kirin_math.stmts.tan,
79-
kirin_math.stmts.exp,
80-
kirin_math.stmts.log2,
81-
kirin_math.stmts.sqrt,
82-
]
83-
84-
assert no_const_actual_sequence == expected_stmt_sequence
85-
86-
8716
def test_qasm2_core():
8817

8918
@qasm2.main(fold=False)

0 commit comments

Comments
 (0)