diff --git a/src/bloqade/gemini/dialects/logical/_interface.py b/src/bloqade/gemini/dialects/logical/_interface.py index b9bb692f..de44390a 100644 --- a/src/bloqade/gemini/dialects/logical/_interface.py +++ b/src/bloqade/gemini/dialects/logical/_interface.py @@ -1,4 +1,4 @@ -from typing import TypeVar +from typing import Any, TypeVar from kirin import lowering from kirin.dialects import ilist @@ -8,13 +8,12 @@ from .stmts import TerminalLogicalMeasurement Len = TypeVar("Len") -CodeN = TypeVar("CodeN") @lowering.wraps(TerminalLogicalMeasurement) def terminal_measure( qubits: ilist.IList[Qubit, Len], -) -> ilist.IList[ilist.IList[MeasurementResult, CodeN], Len]: +) -> ilist.IList[ilist.IList[MeasurementResult, Any], Len]: """Perform measurements on a list of logical qubits. Measurements are returned as a nested list where each member list diff --git a/src/bloqade/gemini/dialects/logical/stmts.py b/src/bloqade/gemini/dialects/logical/stmts.py index 873c1bfd..f43ef8f0 100644 --- a/src/bloqade/gemini/dialects/logical/stmts.py +++ b/src/bloqade/gemini/dialects/logical/stmts.py @@ -6,8 +6,27 @@ from ._dialect import dialect + +@statement(dialect=dialect) +class Initialize(ir.Statement): + """Initialize a list of logical qubits to an arbitrary state. + + Args: + phi (float): Angle for rotation around the Z axis + theta (float): angle for rotation around the Y axis + phi (float): angle for rotation around the Z axis + qubits (IList[QubitType, Len]): The list of logical qubits to initialize + + """ + + traits = frozenset({}) + theta: ir.SSAValue = info.argument(types.Float) + phi: ir.SSAValue = info.argument(types.Float) + lam: ir.SSAValue = info.argument(types.Float) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) + + Len = types.TypeVar("Len") -CodeN = types.TypeVar("CodeN") @statement(dialect=dialect) @@ -28,5 +47,5 @@ class TerminalLogicalMeasurement(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len]) result: ir.ResultValue = info.result( - ilist.IListType[ilist.IListType[MeasurementResultType, CodeN], Len] + ilist.IListType[ilist.IListType[MeasurementResultType, types.Any], Len] ) diff --git a/src/bloqade/gemini/rewrite/__init__.py b/src/bloqade/gemini/rewrite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/gemini/rewrite/initialize.py b/src/bloqade/gemini/rewrite/initialize.py new file mode 100644 index 00000000..8202bed4 --- /dev/null +++ b/src/bloqade/gemini/rewrite/initialize.py @@ -0,0 +1,30 @@ +from kirin import ir +from kirin.rewrite import abc as rewrite_abc + +from bloqade.squin.gate.stmts import U3 + +from ..dialects.logical.stmts import Initialize + + +class __RewriteU3ToInitialize(rewrite_abc.RewriteRule): + """Rewrite U3 gates to Initialize statements. + + Note: + + This rewrite is only valid in the context of logical qubits, where the U3 gate + can be interpreted as initializing a qubit to an arbitrary state. + + The U3 gate with parameters (theta, phi, lam) is equivalent to initializing + a qubit to the state defined by those angles. + + This rewrite also assumes there are no other U3 gates acting on the same qubits + later in the circuit, as that would conflict with the initialization semantics. + + """ + + def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult: + if not isinstance(node, U3): + return rewrite_abc.RewriteResult() + + node.replace_by(Initialize(*node.args)) + return rewrite_abc.RewriteResult(has_done_something=True) diff --git a/test/gemini/test_rewrite.py b/test/gemini/test_rewrite.py new file mode 100644 index 00000000..97082824 --- /dev/null +++ b/test/gemini/test_rewrite.py @@ -0,0 +1,30 @@ +from kirin import ir, rewrite +from kirin.dialects import py + +from bloqade.test_utils import assert_nodes +from bloqade.squin.gate.stmts import U3 +from bloqade.gemini.rewrite.initialize import __RewriteU3ToInitialize +from bloqade.gemini.dialects.logical.stmts import Initialize + + +def test_rewrite_u3_to_initialize(): + theta = ir.TestValue() + phi = ir.TestValue() + qubits = ir.TestValue() + test_block = ir.Block( + [ + lam_stmt := py.Constant(1.0), + U3(theta, phi, lam_stmt.result, qubits), + ] + ) + + expected_block = ir.Block( + [ + lam_stmt := py.Constant(1.0), + Initialize(theta, phi, lam_stmt.result, qubits), + ] + ) + + rewrite.Walk(__RewriteU3ToInitialize()).rewrite(test_block) + + assert_nodes(test_block, expected_block)