Skip to content

Commit 0645d11

Browse files
committed
Merge branch 'main' into john/qasm2-to-squin
2 parents 6390a23 + ca47d3d commit 0645d11

File tree

5 files changed

+260
-9
lines changed

5 files changed

+260
-9
lines changed

src/bloqade/gemini/dialects/logical/_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from .stmts import TerminalLogicalMeasurement
99

10-
Len = TypeVar("Len", bound=int)
11-
CodeN = TypeVar("CodeN", bound=int)
10+
Len = TypeVar("Len")
11+
CodeN = TypeVar("CodeN")
1212

1313

1414
@lowering.wraps(TerminalLogicalMeasurement)

src/bloqade/gemini/dialects/logical/stmts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from ._dialect import dialect
88

9-
Len = types.TypeVar("Len", bound=types.Int)
10-
CodeN = types.TypeVar("CodeN", bound=types.Int)
9+
Len = types.TypeVar("Len")
10+
CodeN = types.TypeVar("CodeN")
1111

1212

1313
@statement(dialect=dialect)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
from kirin.dialects import py
4+
5+
from bloqade.squin.gate import stmts as gate_stmts
6+
7+
8+
class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
9+
"""Rewrite non-Clifford gates to U3 gates.
10+
11+
This rewrite rule transforms specific non-Clifford single-qubit gates
12+
into equivalent U3 gate representations. The following transformations are applied:
13+
- T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4)
14+
- Rx gate to U3 gate with parameters (angle, -π/2, π/2)
15+
- Ry gate to U3 gate with parameters (angle, 0, 0)
16+
- Rz gate is U3 gate with parameters (0, 0, angle)
17+
18+
This rewrite should be paired with `U3ToClifford` to canonicalize the circuit.
19+
20+
"""
21+
22+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
23+
if not isinstance(
24+
node,
25+
(
26+
gate_stmts.T,
27+
gate_stmts.Rx,
28+
gate_stmts.Ry,
29+
gate_stmts.Rz,
30+
),
31+
):
32+
return rewrite_abc.RewriteResult()
33+
34+
rule = getattr(self, f"rewrite_{type(node).__name__}")
35+
36+
return rule(node)
37+
38+
def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult:
39+
if node.adjoint:
40+
lam_value = -1.0 / 8.0
41+
else:
42+
lam_value = 1.0 / 8.0
43+
44+
(theta_stmt := py.Constant(0.0)).insert_before(node)
45+
(phi_stmt := py.Constant(0.0)).insert_before(node)
46+
(lam_stmt := py.Constant(lam_value)).insert_before(node)
47+
48+
node.replace_by(
49+
gate_stmts.U3(
50+
qubits=node.qubits,
51+
theta=theta_stmt.result,
52+
phi=phi_stmt.result,
53+
lam=lam_stmt.result,
54+
)
55+
)
56+
57+
return rewrite_abc.RewriteResult(has_done_something=True)
58+
59+
def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult:
60+
(phi_stmt := py.Constant(-0.25)).insert_before(node)
61+
(lam_stmt := py.Constant(0.25)).insert_before(node)
62+
63+
node.replace_by(
64+
gate_stmts.U3(
65+
qubits=node.qubits,
66+
theta=node.angle,
67+
phi=phi_stmt.result,
68+
lam=lam_stmt.result,
69+
)
70+
)
71+
72+
return rewrite_abc.RewriteResult(has_done_something=True)
73+
74+
def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
75+
(phi_stmt := py.Constant(0.0)).insert_before(node)
76+
(lam_stmt := py.Constant(0.0)).insert_before(node)
77+
78+
node.replace_by(
79+
gate_stmts.U3(
80+
qubits=node.qubits,
81+
theta=node.angle,
82+
phi=phi_stmt.result,
83+
lam=lam_stmt.result,
84+
)
85+
)
86+
87+
return rewrite_abc.RewriteResult(has_done_something=True)
88+
89+
def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult:
90+
(theta_stmt := py.Constant(0.0)).insert_before(node)
91+
(phi_stmt := py.Constant(0.0)).insert_before(node)
92+
93+
node.replace_by(
94+
gate_stmts.U3(
95+
qubits=node.qubits,
96+
theta=theta_stmt.result,
97+
phi=phi_stmt.result,
98+
lam=node.angle,
99+
)
100+
)
101+
102+
return rewrite_abc.RewriteResult(has_done_something=True)

test/gemini/test_logical_validation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,20 @@ def invalid():
115115
invalid.print(analysis=frame.entries)
116116

117117

118-
def test_terminal_measurement():
119-
@gemini.logical.kernel(verify=False)
118+
def test_qalloc_and_terminal_measure_type_valid():
119+
120+
@gemini.logical.kernel(aggressive_unroll=True)
120121
def main():
121122
q = squin.qalloc(3)
122-
m = gemini.logical.terminal_measure(q)
123-
return m
123+
gemini.logical.terminal_measure(q)
124124

125-
main.print()
125+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
126+
validation_result = validator.validate(main)
127+
128+
validation_result.raise_if_invalid()
129+
130+
131+
def test_terminal_measurement():
126132

127133
@gemini.logical.kernel(
128134
verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from kirin import ir, rewrite
2+
from kirin.dialects import py
3+
4+
from bloqade.squin.gate import stmts as gate_stmts
5+
from bloqade.test_utils import assert_nodes
6+
from bloqade.squin.rewrite.non_clifford_to_U3 import RewriteNonCliffordToU3
7+
8+
9+
def test_rewrite_T():
10+
test_qubits = ir.TestValue()
11+
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=False)])
12+
13+
expected_block = ir.Block(
14+
[
15+
theta := py.Constant(0.0),
16+
phi := py.Constant(0.0),
17+
lam := py.Constant(1.0 / 8.0),
18+
gate_stmts.U3(
19+
qubits=test_qubits,
20+
theta=theta.result,
21+
phi=phi.result,
22+
lam=lam.result,
23+
),
24+
]
25+
)
26+
27+
rule = rewrite.Walk(RewriteNonCliffordToU3())
28+
rule.rewrite(test_block)
29+
30+
assert_nodes(test_block, expected_block)
31+
32+
33+
def test_rewrite_Tadj():
34+
test_qubits = ir.TestValue()
35+
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=True)])
36+
37+
expected_block = ir.Block(
38+
[
39+
theta := py.Constant(0.0),
40+
phi := py.Constant(0.0),
41+
lam := py.Constant(-1.0 / 8.0),
42+
gate_stmts.U3(
43+
qubits=test_qubits,
44+
theta=theta.result,
45+
phi=phi.result,
46+
lam=lam.result,
47+
),
48+
]
49+
)
50+
51+
rule = rewrite.Walk(RewriteNonCliffordToU3())
52+
rule.rewrite(test_block)
53+
54+
assert_nodes(test_block, expected_block)
55+
56+
57+
def test_rewrite_Ry():
58+
test_qubits = ir.TestValue()
59+
angle = ir.TestValue()
60+
test_block = ir.Block([gate_stmts.Ry(qubits=test_qubits, angle=angle)])
61+
62+
expected_block = ir.Block(
63+
[
64+
phi := py.Constant(0.0),
65+
lam := py.Constant(0.0),
66+
gate_stmts.U3(
67+
qubits=test_qubits,
68+
theta=angle,
69+
phi=phi.result,
70+
lam=lam.result,
71+
),
72+
]
73+
)
74+
75+
rule = rewrite.Walk(RewriteNonCliffordToU3())
76+
rule.rewrite(test_block)
77+
78+
assert_nodes(test_block, expected_block)
79+
80+
81+
def test_rewrite_Rz():
82+
test_qubits = ir.TestValue()
83+
angle = ir.TestValue()
84+
test_block = ir.Block([gate_stmts.Rz(qubits=test_qubits, angle=angle)])
85+
86+
expected_block = ir.Block(
87+
[
88+
theta := py.Constant(0.0),
89+
phi := py.Constant(0.0),
90+
gate_stmts.U3(
91+
qubits=test_qubits,
92+
theta=theta.result,
93+
phi=phi.result,
94+
lam=angle,
95+
),
96+
]
97+
)
98+
99+
rule = rewrite.Walk(RewriteNonCliffordToU3())
100+
rule.rewrite(test_block)
101+
102+
assert_nodes(test_block, expected_block)
103+
104+
105+
def test_rewrite_Rx():
106+
test_qubits = ir.TestValue()
107+
angle = ir.TestValue()
108+
test_block = ir.Block([gate_stmts.Rx(qubits=test_qubits, angle=angle)])
109+
110+
expected_block = ir.Block(
111+
[
112+
phi := py.Constant(-0.25),
113+
lam := py.Constant(0.25),
114+
gate_stmts.U3(
115+
qubits=test_qubits,
116+
theta=angle,
117+
phi=phi.result,
118+
lam=lam.result,
119+
),
120+
]
121+
)
122+
123+
rule = rewrite.Walk(RewriteNonCliffordToU3())
124+
rule.rewrite(test_block)
125+
126+
assert_nodes(test_block, expected_block)
127+
128+
129+
def test_no_op():
130+
test_qubits = ir.TestValue()
131+
angle = ir.TestValue()
132+
test_block = ir.Block(
133+
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
134+
)
135+
136+
expected_block = ir.Block(
137+
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
138+
)
139+
140+
rule = rewrite.Walk(RewriteNonCliffordToU3())
141+
rule.rewrite(test_block)
142+
143+
assert_nodes(test_block, expected_block)

0 commit comments

Comments
 (0)