Skip to content

Commit 8c5f453

Browse files
committed
Merge branch 'main' into john/repeat-support
2 parents a39a126 + c2a6a32 commit 8c5f453

File tree

5 files changed

+83
-5
lines changed

5 files changed

+83
-5
lines changed

src/bloqade/gemini/dialects/logical/_interface.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypeVar
1+
from typing import Any, TypeVar
22

33
from kirin import lowering
44
from kirin.dialects import ilist
@@ -8,13 +8,12 @@
88
from .stmts import TerminalLogicalMeasurement
99

1010
Len = TypeVar("Len")
11-
CodeN = TypeVar("CodeN")
1211

1312

1413
@lowering.wraps(TerminalLogicalMeasurement)
1514
def terminal_measure(
1615
qubits: ilist.IList[Qubit, Len],
17-
) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]:
16+
) -> ilist.IList[ilist.IList[MeasurementResult, Any], Len]:
1817
"""Perform measurements on a list of logical qubits.
1918
2019
Measurements are returned as a nested list where each member list

src/bloqade/gemini/dialects/logical/stmts.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,27 @@
66

77
from ._dialect import dialect
88

9+
10+
@statement(dialect=dialect)
11+
class Initialize(ir.Statement):
12+
"""Initialize a list of logical qubits to an arbitrary state.
13+
14+
Args:
15+
phi (float): Angle for rotation around the Z axis
16+
theta (float): angle for rotation around the Y axis
17+
phi (float): angle for rotation around the Z axis
18+
qubits (IList[QubitType, Len]): The list of logical qubits to initialize
19+
20+
"""
21+
22+
traits = frozenset({})
23+
theta: ir.SSAValue = info.argument(types.Float)
24+
phi: ir.SSAValue = info.argument(types.Float)
25+
lam: ir.SSAValue = info.argument(types.Float)
26+
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
27+
28+
929
Len = types.TypeVar("Len")
10-
CodeN = types.TypeVar("CodeN")
1130

1231

1332
@statement(dialect=dialect)
@@ -28,5 +47,5 @@ class TerminalLogicalMeasurement(ir.Statement):
2847
traits = frozenset({lowering.FromPythonCall()})
2948
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
3049
result: ir.ResultValue = info.result(
31-
ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len]
50+
ilist.IListType[ilist.IListType[MeasurementResultType, types.Any], Len]
3251
)

src/bloqade/gemini/rewrite/__init__.py

Whitespace-only changes.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
4+
from bloqade.squin.gate.stmts import U3
5+
6+
from ..dialects.logical.stmts import Initialize
7+
8+
9+
class __RewriteU3ToInitialize(rewrite_abc.RewriteRule):
10+
"""Rewrite U3 gates to Initialize statements.
11+
12+
Note:
13+
14+
This rewrite is only valid in the context of logical qubits, where the U3 gate
15+
can be interpreted as initializing a qubit to an arbitrary state.
16+
17+
The U3 gate with parameters (theta, phi, lam) is equivalent to initializing
18+
a qubit to the state defined by those angles.
19+
20+
This rewrite also assumes there are no other U3 gates acting on the same qubits
21+
later in the circuit, as that would conflict with the initialization semantics.
22+
23+
"""
24+
25+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
26+
if not isinstance(node, U3):
27+
return rewrite_abc.RewriteResult()
28+
29+
node.replace_by(Initialize(*node.args))
30+
return rewrite_abc.RewriteResult(has_done_something=True)

test/gemini/test_rewrite.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from kirin import ir, rewrite
2+
from kirin.dialects import py
3+
4+
from bloqade.test_utils import assert_nodes
5+
from bloqade.squin.gate.stmts import U3
6+
from bloqade.gemini.rewrite.initialize import __RewriteU3ToInitialize
7+
from bloqade.gemini.dialects.logical.stmts import Initialize
8+
9+
10+
def test_rewrite_u3_to_initialize():
11+
theta = ir.TestValue()
12+
phi = ir.TestValue()
13+
qubits = ir.TestValue()
14+
test_block = ir.Block(
15+
[
16+
lam_stmt := py.Constant(1.0),
17+
U3(theta, phi, lam_stmt.result, qubits),
18+
]
19+
)
20+
21+
expected_block = ir.Block(
22+
[
23+
lam_stmt := py.Constant(1.0),
24+
Initialize(theta, phi, lam_stmt.result, qubits),
25+
]
26+
)
27+
28+
rewrite.Walk(__RewriteU3ToInitialize()).rewrite(test_block)
29+
30+
assert_nodes(test_block, expected_block)

0 commit comments

Comments
 (0)