diff --git a/src/bloqade/lanes/dialects/circuit.py b/src/bloqade/lanes/dialects/circuit.py index 65a64c5d..28759aeb 100644 --- a/src/bloqade/lanes/dialects/circuit.py +++ b/src/bloqade/lanes/dialects/circuit.py @@ -1,27 +1,28 @@ from kirin import ir, types from kirin.decl import info, statement -from .execute import ExitLowLevel, LowLevelStmt +from .execute import ExitLowLevel, QuantumStmt dialect = ir.Dialect(name="lowlevel.circuit") @statement(dialect=dialect) -class CZ(LowLevelStmt): - pairs: tuple[tuple[int, int], ...] = info.attribute() +class CZ(QuantumStmt): + targets: tuple[int, ...] = info.attribute() + controls: tuple[int, ...] = info.attribute() @statement(dialect=dialect) -class R(LowLevelStmt): - inputs: tuple[int, ...] = info.attribute() +class R(QuantumStmt): + qubits: tuple[int, ...] = info.attribute() axis_angle: ir.SSAValue = info.argument(type=types.Float) rotation_angle: ir.SSAValue = info.argument(type=types.Float) @statement(dialect=dialect) -class Rz(LowLevelStmt): - inputs: tuple[int, ...] = info.attribute() +class Rz(QuantumStmt): + qubits: tuple[int, ...] = info.attribute() rotation_angle: ir.SSAValue = info.argument(type=types.Float) diff --git a/src/bloqade/lanes/dialects/cpu.py b/src/bloqade/lanes/dialects/cpu.py index 910a4934..d89a24f3 100644 --- a/src/bloqade/lanes/dialects/cpu.py +++ b/src/bloqade/lanes/dialects/cpu.py @@ -1,23 +1,12 @@ from kirin import ir, types from kirin.decl import info, statement -from .execute import LowLevelStmt - dialect = ir.Dialect(name="lowlevel.cpu") @statement(dialect=dialect) -class StaticFloat(LowLevelStmt): +class StaticFloat(ir.Statement): traits = frozenset({ir.ConstantLike()}) value: float = info.attribute(type=types.Float) result: ir.ResultValue = info.result(type=types.Float) - - -@statement(dialect=dialect) -class StaticInt(LowLevelStmt): - traits = frozenset({ir.ConstantLike()}) - - value: int = info.attribute(type=types.Int) - - result: ir.ResultValue = info.result(type=types.Int) diff --git a/src/bloqade/lanes/dialects/execute.py b/src/bloqade/lanes/dialects/execute.py index 7042ad9d..dd5eac3f 100644 --- a/src/bloqade/lanes/dialects/execute.py +++ b/src/bloqade/lanes/dialects/execute.py @@ -9,11 +9,11 @@ @statement -class LowLevelStmt(ir.Statement): +class QuantumStmt(ir.Statement): """This is a base class for all low level statements.""" state_before: ir.SSAValue = info.argument(StateType) - result: ir.ResultValue = info.result(StateType) + state_after: ir.ResultValue = info.result(StateType) @statement @@ -39,7 +39,6 @@ class ExecuteLowLevel(ir.Statement): qubits: tuple[ir.SSAValue, ...] = info.argument(bloqade_types.QubitType) body: ir.Region = info.region(multi=False) starting_addresses: tuple[LocationAddress, ...] | None = info.attribute() - measure_result: tuple[ir.ResultValue, ...] = info.result() def __init__( self, @@ -81,7 +80,7 @@ def check(self) -> None: stmt = body_block.first_stmt while stmt is not last_stmt: - if not isinstance(stmt, LowLevelStmt): + if not isinstance(stmt, QuantumStmt): raise exception.StaticCheckError( "All statements in ShuttleAtoms body must be ByteCodeStmt" ) diff --git a/src/bloqade/lanes/dialects/move.py b/src/bloqade/lanes/dialects/move.py index 67e1db58..c3dab541 100644 --- a/src/bloqade/lanes/dialects/move.py +++ b/src/bloqade/lanes/dialects/move.py @@ -2,18 +2,18 @@ from kirin.decl import info, statement from ..layout.encoding import LocationAddress, MoveType -from .execute import ExitLowLevel, LowLevelStmt +from .execute import ExitLowLevel, QuantumStmt dialect = ir.Dialect(name="lowlevel.move") @statement(dialect=dialect) -class CZ(LowLevelStmt): +class CZ(QuantumStmt): pass @statement(dialect=dialect) -class LocalR(LowLevelStmt): +class LocalR(QuantumStmt): physical_addr: tuple[LocationAddress, ...] = info.attribute() axis_angle: ir.SSAValue = info.argument(type=types.Float) @@ -21,23 +21,23 @@ class LocalR(LowLevelStmt): @statement(dialect=dialect) -class GlobalR(LowLevelStmt): +class GlobalR(QuantumStmt): axis_angle: ir.SSAValue = info.argument(type=types.Float) rotation_angle: ir.SSAValue = info.argument(type=types.Float) @statement(dialect=dialect) -class LocalRz(LowLevelStmt): +class LocalRz(QuantumStmt): physical_addr: tuple[LocationAddress, ...] = info.attribute() rotation_angle: ir.SSAValue = info.argument(type=types.Float) @statement(dialect=dialect) -class GlobalRz(LowLevelStmt): +class GlobalRz(QuantumStmt): rotation_angle: ir.SSAValue = info.argument(type=types.Float) -class Move(LowLevelStmt): +class Move(QuantumStmt): lanes: tuple[MoveType, ...] = info.attribute() diff --git a/src/bloqade/lanes/passes/__init__.py b/src/bloqade/lanes/passes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/lanes/passes/canonicalize.py b/src/bloqade/lanes/passes/canonicalize.py new file mode 100644 index 00000000..1565daa5 --- /dev/null +++ b/src/bloqade/lanes/passes/canonicalize.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass + +from bloqade.native.dialects.gate import stmts as gates +from kirin import ir, passes, rewrite +from kirin.passes.hint_const import HintConst +from kirin.rewrite import abc + + +class HoistClassicalStatements(abc.RewriteRule): + """This rewrite rule shift any classical statements that are pure + (quantum statements are never pure) to the beginning of the block. + swapping the other with quantum statements. This is useful after + rewriting the native operations to placement operations, + so that we can merge the placement regions together. + + Note that this rule also works with subroutines that contain + quantum statements because these are also not pure + + """ + + def is_pure(self, node: ir.Statement) -> bool: + return ( + node.has_trait(ir.Pure) + or (maybe_pure := node.get_trait(ir.MaybePure)) is not None + and maybe_pure.is_pure(node) + ) + + def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: + if not ( + isinstance(node, (gates.CZ, gates.R, gates.Rz)) + and (next_node := node.next_stmt) is not None + and not next_node.has_trait(ir.IsTerminator) + and self.is_pure(next_node) + ): + return abc.RewriteResult() + + next_node.detach() + next_node.insert_before(node) + return abc.RewriteResult(has_done_something=True) + + +@dataclass +class CanonicalizeNative(passes.Pass): + + def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult: + result = HintConst(mt.dialects)(mt) + result = ( + rewrite.Fixpoint(rewrite.Walk(HoistClassicalStatements())) + .rewrite(mt.code) + .join(result) + ) + + return result diff --git a/src/bloqade/lanes/rewrite/__init__.py b/src/bloqade/lanes/rewrite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/lanes/rewrite/native2circuit.py b/src/bloqade/lanes/rewrite/native2circuit.py new file mode 100644 index 00000000..083a513d --- /dev/null +++ b/src/bloqade/lanes/rewrite/native2circuit.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass + +from bloqade.native.dialects.gate import stmts as gate +from kirin import ir +from kirin.dialects import ilist, py +from kirin.rewrite import abc + +from bloqade.lanes.dialects import circuit, cpu, execute +from bloqade.lanes.types import StateType + + +@dataclass +class RewriteLowLevelCircuit(abc.RewriteRule): + """ + Rewrite rule to convert native operations to placement operations. + This is a placeholder for the actual implementation. + """ + + def default_(self, node: ir.Statement) -> abc.RewriteResult: + return abc.RewriteResult() + + def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: + rewrite_method_name = f"rewrite_{type(node).__name__}" + rewrite_method = getattr(self, rewrite_method_name, self.default_) + return rewrite_method(node) + + def prep_region(self) -> tuple[ir.Region, ir.Block, ir.SSAValue]: + body = ir.Region(block := ir.Block()) + entry_state = block.args.append_from(StateType, name="entry_state") + return body, block, entry_state + + def construct_execute( + self, + gate_stmt: execute.QuantumStmt, + *, + qubits: tuple[ir.SSAValue, ...], + body: ir.Region, + block: ir.Block, + ) -> execute.ExecuteLowLevel: + block.stmts.append(gate_stmt) + block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + return execute.ExecuteLowLevel(qubits=qubits, body=body) + + def rewrite_CZ(self, node: gate.CZ) -> abc.RewriteResult: + if not isinstance( + targets_list := node.targets.owner, ilist.New + ) or not isinstance(controls_list := node.controls.owner, ilist.New): + return abc.RewriteResult() + + targets = targets_list.values + controls = controls_list.values + if len(targets) != len(controls): + return abc.RewriteResult() + + all_qubits = tuple(range(len(targets) + len(controls))) + n_controls = len(controls) + + body, block, entry_state = self.prep_region() + stmt = circuit.CZ( + entry_state, + controls=all_qubits[:n_controls], + targets=all_qubits[n_controls:], + ) + + node.replace_by( + self.construct_execute( + stmt, qubits=controls + targets, body=body, block=block + ) + ) + + return abc.RewriteResult(has_done_something=True) + + def rewrite_R(self, node: gate.R) -> abc.RewriteResult: + if not isinstance(args_list := node.qubits.owner, ilist.New): + return abc.RewriteResult() + + inputs = args_list.values + + body, block, entry_state = self.prep_region() + gate_stmt = circuit.R( + entry_state, + qubits=tuple(range(len(inputs))), + axis_angle=node.axis_angle, + rotation_angle=node.rotation_angle, + ) + node.replace_by( + self.construct_execute(gate_stmt, qubits=inputs, body=body, block=block) + ) + + return abc.RewriteResult(has_done_something=True) + + def rewrite_Rz(self, node: gate.Rz) -> abc.RewriteResult: + if not isinstance(args_list := node.qubits.owner, ilist.New): + return abc.RewriteResult() + + inputs = args_list.values + + body = ir.Region(block := ir.Block()) + entry_state = block.args.append_from(StateType, name="entry_state") + + gate_stmt = circuit.Rz( + entry_state, + qubits=tuple(range(len(inputs))), + rotation_angle=node.rotation_angle, + ) + + node.replace_by( + self.construct_execute(gate_stmt, qubits=inputs, body=body, block=block) + ) + + return abc.RewriteResult(has_done_something=True) + + +class RewriteConstantToStatic(abc.RewriteRule): + """ + Rewrite rule to convert constant values to static float values. + """ + + def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: + if not ( + isinstance(node, py.Constant) + and isinstance(value := node.value.unwrap(), float) + ): + return abc.RewriteResult() + + node.replace_by(cpu.StaticFloat(value=value)) + + return abc.RewriteResult(has_done_something=True) + + +class MergePlacementRegions(abc.RewriteRule): + """ + Merge adjacent placement regions into a single region. + This is a placeholder for the actual implementation. + """ + + def remap_qubits( + self, node: circuit.R | circuit.Rz | circuit.CZ, input_map: dict[int, int] + ) -> circuit.R | circuit.Rz | circuit.CZ: + if isinstance(node, circuit.CZ): + return circuit.CZ( + node.state_before, + targets=tuple(input_map[i] for i in node.targets), + controls=tuple(input_map[i] for i in node.controls), + ) + else: + return node.from_stmt( + node, + attributes={ + "qubits": ir.PyAttr(tuple(input_map[i] for i in node.qubits)) + }, + ) + + def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: + if not ( + isinstance(node, execute.ExecuteLowLevel) + and isinstance(next_node := node.next_stmt, execute.ExecuteLowLevel) + ): + return abc.RewriteResult() + + new_qubits = node.qubits + new_input_map = {} + for old_qid, qbit in enumerate(next_node.qubits): + if qbit not in new_qubits: + new_input_map[old_qid] = len(new_qubits) + new_qubits = new_qubits + (qbit,) + else: + new_input_map[old_qid] = new_qubits.index(qbit) + + new_body = node.body.clone() + + curr_state = None + stmt = (curr_block := new_body.blocks[0]).last_stmt + assert isinstance(stmt, execute.ExitLowLevel) + curr_state = stmt.state + stmt.delete() + + # make sure to copy list of blocks since the loop body will + # mutate the list contained inside of `next_node.body.blocks` + next_block = next_node.body.blocks[0] + stmt = next_block.first_stmt + while stmt: + next_stmt = stmt.next_stmt + stmt.detach() + if isinstance(stmt, (circuit.CZ, circuit.R, circuit.Rz)): + curr_block.stmts.append( + new_stmt := self.remap_qubits(stmt, new_input_map) + ) + curr_state = new_stmt.state_after + elif isinstance(stmt, execute.ExitLowLevel): + curr_block.stmts.append(type(stmt)(state=curr_state)) + else: + curr_block.stmts.append(stmt) + + stmt = next_stmt + + # replace next node with the new merged region + next_node.replace_by( + execute.ExecuteLowLevel( + qubits=new_qubits, + body=new_body, + ) + ) + # delete the node + node.delete() + + return abc.RewriteResult(has_done_something=True) diff --git a/test/rewrite/__init__.py b/test/rewrite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/rewrite/test_native2placement.py b/test/rewrite/test_native2placement.py new file mode 100644 index 00000000..ff6ebe89 --- /dev/null +++ b/test/rewrite/test_native2placement.py @@ -0,0 +1,185 @@ +from bloqade.native.dialects.gate import stmts as gates +from bloqade.test_utils import assert_nodes +from kirin import ir, rewrite +from kirin.dialects import ilist, py + +from bloqade.lanes import types +from bloqade.lanes.dialects import circuit, execute +from bloqade.lanes.rewrite.native2circuit import ( + MergePlacementRegions, + RewriteLowLevelCircuit, +) + + +def test_cz(): + + test_block = ir.Block( + [ + targets := ilist.New(values=(q0 := ir.TestValue(), q1 := ir.TestValue())), + controls := ilist.New(values=(c0 := ir.TestValue(), c1 := ir.TestValue())), + gates.CZ(targets=targets.result, controls=controls.result), + ], + ) + + expected_block = ir.Block( + [ + targets := ilist.New(values=(q0, q1)), + controls := ilist.New(values=(c0, c1)), + execute.ExecuteLowLevel( + qubits=(c0, c1, q0, q1), body=ir.Region(block := ir.Block()) + ), + ] + ) + + entry_state = block.args.append_from(types.StateType, name="entry_state") + block.stmts.append( + gate_stmt := circuit.CZ(entry_state, controls=(0, 1), targets=(2, 3)) + ) + block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + rule = rewrite.Walk(RewriteLowLevelCircuit()) + + rule.rewrite(test_block) + + assert_nodes(test_block, expected_block) + + +test_cz() + + +def test_r(): + axis_angle = ir.TestValue() + rotation_angle = ir.TestValue() + test_block = ir.Block( + [ + inputs := ilist.New(values=(q0 := ir.TestValue(), q1 := ir.TestValue())), + gates.R( + qubits=inputs.result, + axis_angle=axis_angle, + rotation_angle=rotation_angle, + ), + ], + ) + + expected_block = ir.Block( + [ + inputs := ilist.New(values=(q0, q1)), + execute.ExecuteLowLevel( + qubits=(q0, q1), body=ir.Region(block := ir.Block()) + ), + ] + ) + + entry_state = block.args.append_from(types.StateType, name="entry_state") + block.stmts.append( + gate_stmt := circuit.R( + entry_state, + qubits=(0, 1), + axis_angle=axis_angle, + rotation_angle=rotation_angle, + ) + ) + block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + rule = rewrite.Walk(RewriteLowLevelCircuit()) + + rule.rewrite(test_block) + + assert_nodes(test_block, expected_block) + + +def test_rz(): + rotation_angle = ir.TestValue() + test_block = ir.Block( + [ + qubits := ilist.New(values=(q0 := ir.TestValue(), q1 := ir.TestValue())), + gates.Rz(qubits=qubits.result, rotation_angle=rotation_angle), + ], + ) + + expected_block = ir.Block( + [ + qubits := ilist.New(values=(q0, q1)), + execute.ExecuteLowLevel( + qubits=(q0, q1), body=ir.Region(block := ir.Block()) + ), + ] + ) + + entry_state = block.args.append_from(types.StateType, name="entry_state") + block.stmts.append( + gate_stmt := circuit.Rz( + entry_state, qubits=(0, 1), rotation_angle=rotation_angle + ) + ) + block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + rule = rewrite.Walk(RewriteLowLevelCircuit()) + + rule.rewrite(test_block) + + assert_nodes(test_block, expected_block) + + +def test_merge_regions(): + + qubits = tuple(ir.TestValue() for _ in range(10)) + + test_block = ir.Block( + [ + rotation_angle := py.Constant(0.5), + execute.ExecuteLowLevel( + qubits=(qubits[0], qubits[1]), + body=ir.Region(body_block := ir.Block()), + ), + execute.ExecuteLowLevel( + qubits=(qubits[2], qubits[3]), + body=ir.Region(second_block := ir.Block()), + ), + ] + ) + + entry_state = body_block.args.append_from(types.StateType, name="entry_state") + body_block.stmts.append( + gate_stmt := circuit.Rz( + entry_state, qubits=(0, 1), rotation_angle=rotation_angle.result + ) + ) + body_block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + entry_state = second_block.args.append_from(types.StateType, name="entry_state") + second_block.stmts.append( + gate_stmt := circuit.Rz( + entry_state, qubits=(0, 1), rotation_angle=rotation_angle.result + ) + ) + second_block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + expected_block = ir.Block( + [ + rotation_angle := py.Constant(0.5), + execute.ExecuteLowLevel( + qubits=(qubits[0], qubits[1], qubits[2], qubits[3]), + body=ir.Region(body_block := ir.Block()), + ), + ] + ) + + entry_state = body_block.args.append_from(types.StateType, name="entry_state") + body_block.stmts.append( + ( + gate_stmt := circuit.Rz( + entry_state, qubits=(0, 1), rotation_angle=rotation_angle.result + ) + ) + ) + body_block.stmts.append( + gate_stmt := circuit.Rz( + gate_stmt.state_after, qubits=(2, 3), rotation_angle=rotation_angle.result + ) + ) + body_block.stmts.append(execute.ExitLowLevel(state=gate_stmt.state_after)) + + rewrite.Walk(MergePlacementRegions()).rewrite(test_block) + + assert_nodes(test_block, expected_block)