Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/bloqade/gemini/dialects/logical/_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import Any, TypeVar

from kirin import lowering
from kirin.dialects import ilist
Expand All @@ -8,13 +8,12 @@
from .stmts import TerminalLogicalMeasurement

Len = TypeVar("Len")
CodeN = TypeVar("CodeN")


@lowering.wraps(TerminalLogicalMeasurement)
def terminal_measure(
qubits: ilist.IList[Qubit, Len],
) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]:
) -> ilist.IList[ilist.IList[MeasurementResult, Any], Len]:
"""Perform measurements on a list of logical qubits.

Measurements are returned as a nested list where each member list
Expand Down
23 changes: 21 additions & 2 deletions src/bloqade/gemini/dialects/logical/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,27 @@

from ._dialect import dialect


@statement(dialect=dialect)
class Initialize(ir.Statement):
"""Initialize a list of logical qubits to an arbitrary state.

Args:
phi (float): Angle for rotation around the Z axis
theta (float): angle for rotation around the Y axis
phi (float): angle for rotation around the Z axis
qubits (IList[QubitType, Len]): The list of logical qubits to initialize

"""

traits = frozenset({})
theta: ir.SSAValue = info.argument(types.Float)
phi: ir.SSAValue = info.argument(types.Float)
lam: ir.SSAValue = info.argument(types.Float)
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


Len = types.TypeVar("Len")
CodeN = types.TypeVar("CodeN")


@statement(dialect=dialect)
Expand All @@ -28,5 +47,5 @@ class TerminalLogicalMeasurement(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len])
result: ir.ResultValue = info.result(
ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len]
ilist.IListType[ilist.IListType[MeasurementResultType, types.Any], Len]
)
Empty file.
30 changes: 30 additions & 0 deletions src/bloqade/gemini/rewrite/initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from kirin import ir
from kirin.rewrite import abc as rewrite_abc

from bloqade.squin.gate.stmts import U3

from ..dialects.logical.stmts import Initialize


class __RewriteU3ToInitialize(rewrite_abc.RewriteRule):
"""Rewrite U3 gates to Initialize statements.

Note:

This rewrite is only valid in the context of logical qubits, where the U3 gate
can be interpreted as initializing a qubit to an arbitrary state.

The U3 gate with parameters (theta, phi, lam) is equivalent to initializing
a qubit to the state defined by those angles.

This rewrite also assumes there are no other U3 gates acting on the same qubits
later in the circuit, as that would conflict with the initialization semantics.

"""

def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
if not isinstance(node, U3):
return rewrite_abc.RewriteResult()

node.replace_by(Initialize(*node.args))
return rewrite_abc.RewriteResult(has_done_something=True)
30 changes: 30 additions & 0 deletions test/gemini/test_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from kirin import ir, rewrite
from kirin.dialects import py

from bloqade.test_utils import assert_nodes
from bloqade.squin.gate.stmts import U3
from bloqade.gemini.rewrite.initialize import __RewriteU3ToInitialize
from bloqade.gemini.dialects.logical.stmts import Initialize


def test_rewrite_u3_to_initialize():
theta = ir.TestValue()
phi = ir.TestValue()
qubits = ir.TestValue()
test_block = ir.Block(
[
lam_stmt := py.Constant(1.0),
U3(theta, phi, lam_stmt.result, qubits),
]
)

expected_block = ir.Block(
[
lam_stmt := py.Constant(1.0),
Initialize(theta, phi, lam_stmt.result, qubits),
]
)

rewrite.Walk(__RewriteU3ToInitialize()).rewrite(test_block)

assert_nodes(test_block, expected_block)
Loading