Skip to content

Commit 2ff90a4

Browse files
committed
Merge branch 'main' into david/fix-logical-validation-state-prep
2 parents b7fba43 + c2a6a32 commit 2ff90a4

File tree

8 files changed

+341
-12
lines changed

8 files changed

+341
-12
lines changed

src/bloqade/gemini/dialects/logical/_interface.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypeVar
1+
from typing import Any, TypeVar
22

33
from kirin import lowering
44
from kirin.dialects import ilist
@@ -7,14 +7,13 @@
77

88
from .stmts import TerminalLogicalMeasurement
99

10-
Len = TypeVar("Len", bound=int)
11-
CodeN = TypeVar("CodeN", bound=int)
10+
Len = TypeVar("Len")
1211

1312

1413
@lowering.wraps(TerminalLogicalMeasurement)
1514
def terminal_measure(
1615
qubits: ilist.IList[Qubit, Len],
17-
) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]:
16+
) -> ilist.IList[ilist.IList[MeasurementResult, Any], Len]:
1817
"""Perform measurements on a list of logical qubits.
1918
2019
Measurements are returned as a nested list where each member list

src/bloqade/gemini/dialects/logical/stmts.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,27 @@
66

77
from ._dialect import dialect
88

9-
Len = types.TypeVar("Len", bound=types.Int)
10-
CodeN = types.TypeVar("CodeN", bound=types.Int)
9+
10+
@statement(dialect=dialect)
11+
class Initialize(ir.Statement):
12+
"""Initialize a list of logical qubits to an arbitrary state.
13+
14+
Args:
15+
phi (float): Angle for rotation around the Z axis
16+
theta (float): angle for rotation around the Y axis
17+
phi (float): angle for rotation around the Z axis
18+
qubits (IList[QubitType, Len]): The list of logical qubits to initialize
19+
20+
"""
21+
22+
traits = frozenset({})
23+
theta: ir.SSAValue = info.argument(types.Float)
24+
phi: ir.SSAValue = info.argument(types.Float)
25+
lam: ir.SSAValue = info.argument(types.Float)
26+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
27+
28+
29+
Len = types.TypeVar("Len")
1130

1231

1332
@statement(dialect=dialect)
@@ -28,5 +47,5 @@ class TerminalLogicalMeasurement(ir.Statement):
2847
traits = frozenset({lowering.FromPythonCall()})
2948
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
3049
result: ir.ResultValue = info.result(
31-
ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len]
50+
ilist.IListType[ilist.IListType[MeasurementResultType, types.Any], Len]
3251
)

src/bloqade/gemini/rewrite/__init__.py

Whitespace-only changes.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
4+
from bloqade.squin.gate.stmts import U3
5+
6+
from ..dialects.logical.stmts import Initialize
7+
8+
9+
class __RewriteU3ToInitialize(rewrite_abc.RewriteRule):
10+
"""Rewrite U3 gates to Initialize statements.
11+
12+
Note:
13+
14+
This rewrite is only valid in the context of logical qubits, where the U3 gate
15+
can be interpreted as initializing a qubit to an arbitrary state.
16+
17+
The U3 gate with parameters (theta, phi, lam) is equivalent to initializing
18+
a qubit to the state defined by those angles.
19+
20+
This rewrite also assumes there are no other U3 gates acting on the same qubits
21+
later in the circuit, as that would conflict with the initialization semantics.
22+
23+
"""
24+
25+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
26+
if not isinstance(node, U3):
27+
return rewrite_abc.RewriteResult()
28+
29+
node.replace_by(Initialize(*node.args))
30+
return rewrite_abc.RewriteResult(has_done_something=True)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
from kirin.dialects import py
4+
5+
from bloqade.squin.gate import stmts as gate_stmts
6+
7+
8+
class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
9+
"""Rewrite non-Clifford gates to U3 gates.
10+
11+
This rewrite rule transforms specific non-Clifford single-qubit gates
12+
into equivalent U3 gate representations. The following transformations are applied:
13+
- T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4)
14+
- Rx gate to U3 gate with parameters (angle, -π/2, π/2)
15+
- Ry gate to U3 gate with parameters (angle, 0, 0)
16+
- Rz gate is U3 gate with parameters (0, 0, angle)
17+
18+
This rewrite should be paired with `U3ToClifford` to canonicalize the circuit.
19+
20+
"""
21+
22+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
23+
if not isinstance(
24+
node,
25+
(
26+
gate_stmts.T,
27+
gate_stmts.Rx,
28+
gate_stmts.Ry,
29+
gate_stmts.Rz,
30+
),
31+
):
32+
return rewrite_abc.RewriteResult()
33+
34+
rule = getattr(self, f"rewrite_{type(node).__name__}")
35+
36+
return rule(node)
37+
38+
def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult:
39+
if node.adjoint:
40+
lam_value = -1.0 / 8.0
41+
else:
42+
lam_value = 1.0 / 8.0
43+
44+
(theta_stmt := py.Constant(0.0)).insert_before(node)
45+
(phi_stmt := py.Constant(0.0)).insert_before(node)
46+
(lam_stmt := py.Constant(lam_value)).insert_before(node)
47+
48+
node.replace_by(
49+
gate_stmts.U3(
50+
qubits=node.qubits,
51+
theta=theta_stmt.result,
52+
phi=phi_stmt.result,
53+
lam=lam_stmt.result,
54+
)
55+
)
56+
57+
return rewrite_abc.RewriteResult(has_done_something=True)
58+
59+
def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult:
60+
(phi_stmt := py.Constant(-0.25)).insert_before(node)
61+
(lam_stmt := py.Constant(0.25)).insert_before(node)
62+
63+
node.replace_by(
64+
gate_stmts.U3(
65+
qubits=node.qubits,
66+
theta=node.angle,
67+
phi=phi_stmt.result,
68+
lam=lam_stmt.result,
69+
)
70+
)
71+
72+
return rewrite_abc.RewriteResult(has_done_something=True)
73+
74+
def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
75+
(phi_stmt := py.Constant(0.0)).insert_before(node)
76+
(lam_stmt := py.Constant(0.0)).insert_before(node)
77+
78+
node.replace_by(
79+
gate_stmts.U3(
80+
qubits=node.qubits,
81+
theta=node.angle,
82+
phi=phi_stmt.result,
83+
lam=lam_stmt.result,
84+
)
85+
)
86+
87+
return rewrite_abc.RewriteResult(has_done_something=True)
88+
89+
def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult:
90+
(theta_stmt := py.Constant(0.0)).insert_before(node)
91+
(phi_stmt := py.Constant(0.0)).insert_before(node)
92+
93+
node.replace_by(
94+
gate_stmts.U3(
95+
qubits=node.qubits,
96+
theta=theta_stmt.result,
97+
phi=phi_stmt.result,
98+
lam=node.angle,
99+
)
100+
)
101+
102+
return rewrite_abc.RewriteResult(has_done_something=True)

test/gemini/test_logical_validation.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,20 @@ def invalid():
119119
invalid.print(analysis=frame.entries)
120120

121121

122-
def test_terminal_measurement():
123-
@gemini.logical.kernel(verify=False)
122+
def test_qalloc_and_terminal_measure_type_valid():
123+
124+
@gemini.logical.kernel(aggressive_unroll=True)
124125
def main():
125126
q = squin.qalloc(3)
126-
m = gemini.logical.terminal_measure(q)
127-
return m
127+
gemini.logical.terminal_measure(q)
128128

129-
main.print()
129+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
130+
validation_result = validator.validate(main)
131+
132+
validation_result.raise_if_invalid()
133+
134+
135+
def test_terminal_measurement():
130136

131137
@gemini.logical.kernel(
132138
verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True

test/gemini/test_rewrite.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from kirin import ir, rewrite
2+
from kirin.dialects import py
3+
4+
from bloqade.test_utils import assert_nodes
5+
from bloqade.squin.gate.stmts import U3
6+
from bloqade.gemini.rewrite.initialize import __RewriteU3ToInitialize
7+
from bloqade.gemini.dialects.logical.stmts import Initialize
8+
9+
10+
def test_rewrite_u3_to_initialize():
11+
theta = ir.TestValue()
12+
phi = ir.TestValue()
13+
qubits = ir.TestValue()
14+
test_block = ir.Block(
15+
[
16+
lam_stmt := py.Constant(1.0),
17+
U3(theta, phi, lam_stmt.result, qubits),
18+
]
19+
)
20+
21+
expected_block = ir.Block(
22+
[
23+
lam_stmt := py.Constant(1.0),
24+
Initialize(theta, phi, lam_stmt.result, qubits),
25+
]
26+
)
27+
28+
rewrite.Walk(__RewriteU3ToInitialize()).rewrite(test_block)
29+
30+
assert_nodes(test_block, expected_block)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from kirin import ir, rewrite
2+
from kirin.dialects import py
3+
4+
from bloqade.squin.gate import stmts as gate_stmts
5+
from bloqade.test_utils import assert_nodes
6+
from bloqade.squin.rewrite.non_clifford_to_U3 import RewriteNonCliffordToU3
7+
8+
9+
def test_rewrite_T():
10+
test_qubits = ir.TestValue()
11+
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=False)])
12+
13+
expected_block = ir.Block(
14+
[
15+
theta := py.Constant(0.0),
16+
phi := py.Constant(0.0),
17+
lam := py.Constant(1.0 / 8.0),
18+
gate_stmts.U3(
19+
qubits=test_qubits,
20+
theta=theta.result,
21+
phi=phi.result,
22+
lam=lam.result,
23+
),
24+
]
25+
)
26+
27+
rule = rewrite.Walk(RewriteNonCliffordToU3())
28+
rule.rewrite(test_block)
29+
30+
assert_nodes(test_block, expected_block)
31+
32+
33+
def test_rewrite_Tadj():
34+
test_qubits = ir.TestValue()
35+
test_block = ir.Block([gate_stmts.T(qubits=test_qubits, adjoint=True)])
36+
37+
expected_block = ir.Block(
38+
[
39+
theta := py.Constant(0.0),
40+
phi := py.Constant(0.0),
41+
lam := py.Constant(-1.0 / 8.0),
42+
gate_stmts.U3(
43+
qubits=test_qubits,
44+
theta=theta.result,
45+
phi=phi.result,
46+
lam=lam.result,
47+
),
48+
]
49+
)
50+
51+
rule = rewrite.Walk(RewriteNonCliffordToU3())
52+
rule.rewrite(test_block)
53+
54+
assert_nodes(test_block, expected_block)
55+
56+
57+
def test_rewrite_Ry():
58+
test_qubits = ir.TestValue()
59+
angle = ir.TestValue()
60+
test_block = ir.Block([gate_stmts.Ry(qubits=test_qubits, angle=angle)])
61+
62+
expected_block = ir.Block(
63+
[
64+
phi := py.Constant(0.0),
65+
lam := py.Constant(0.0),
66+
gate_stmts.U3(
67+
qubits=test_qubits,
68+
theta=angle,
69+
phi=phi.result,
70+
lam=lam.result,
71+
),
72+
]
73+
)
74+
75+
rule = rewrite.Walk(RewriteNonCliffordToU3())
76+
rule.rewrite(test_block)
77+
78+
assert_nodes(test_block, expected_block)
79+
80+
81+
def test_rewrite_Rz():
82+
test_qubits = ir.TestValue()
83+
angle = ir.TestValue()
84+
test_block = ir.Block([gate_stmts.Rz(qubits=test_qubits, angle=angle)])
85+
86+
expected_block = ir.Block(
87+
[
88+
theta := py.Constant(0.0),
89+
phi := py.Constant(0.0),
90+
gate_stmts.U3(
91+
qubits=test_qubits,
92+
theta=theta.result,
93+
phi=phi.result,
94+
lam=angle,
95+
),
96+
]
97+
)
98+
99+
rule = rewrite.Walk(RewriteNonCliffordToU3())
100+
rule.rewrite(test_block)
101+
102+
assert_nodes(test_block, expected_block)
103+
104+
105+
def test_rewrite_Rx():
106+
test_qubits = ir.TestValue()
107+
angle = ir.TestValue()
108+
test_block = ir.Block([gate_stmts.Rx(qubits=test_qubits, angle=angle)])
109+
110+
expected_block = ir.Block(
111+
[
112+
phi := py.Constant(-0.25),
113+
lam := py.Constant(0.25),
114+
gate_stmts.U3(
115+
qubits=test_qubits,
116+
theta=angle,
117+
phi=phi.result,
118+
lam=lam.result,
119+
),
120+
]
121+
)
122+
123+
rule = rewrite.Walk(RewriteNonCliffordToU3())
124+
rule.rewrite(test_block)
125+
126+
assert_nodes(test_block, expected_block)
127+
128+
129+
def test_no_op():
130+
test_qubits = ir.TestValue()
131+
angle = ir.TestValue()
132+
test_block = ir.Block(
133+
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
134+
)
135+
136+
expected_block = ir.Block(
137+
[gate_stmts.U3(qubits=test_qubits, theta=angle, phi=angle, lam=angle)]
138+
)
139+
140+
rule = rewrite.Walk(RewriteNonCliffordToU3())
141+
rule.rewrite(test_block)
142+
143+
assert_nodes(test_block, expected_block)

0 commit comments

Comments
 (0)