|
| 1 | +# create rewrite rule name SquinMeasureToStim using kirin |
| 2 | +import math |
| 3 | +from typing import List, Tuple, Callable |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from kirin import ir |
| 7 | +from kirin.dialects import py |
| 8 | +from kirin.rewrite.abc import RewriteRule, RewriteResult |
| 9 | + |
| 10 | +from bloqade.squin import op, qubit |
| 11 | + |
| 12 | + |
| 13 | +def sdag() -> list[ir.Statement]: |
| 14 | + return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)] |
| 15 | + |
| 16 | + |
| 17 | +# (theta, phi, lam) |
| 18 | +U3_HALF_PI_ANGLE_TO_GATES: dict[ |
| 19 | + tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]] |
| 20 | +] = { |
| 21 | + (0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],), |
| 22 | + (0, 0, 1): lambda: ([op.stmts.S()],), |
| 23 | + (0, 0, 2): lambda: ([op.stmts.Z()],), |
| 24 | + (0, 0, 3): lambda: (sdag(),), |
| 25 | + (1, 0, 0): lambda: ([op.stmts.SqrtY()],), |
| 26 | + (1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]), |
| 27 | + (1, 0, 2): lambda: ([op.stmts.H()],), |
| 28 | + (1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]), |
| 29 | + (1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]), |
| 30 | + (1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]), |
| 31 | + (1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]), |
| 32 | + (1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]), |
| 33 | + (1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]), |
| 34 | + (1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]), |
| 35 | + (1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]), |
| 36 | + (1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]), |
| 37 | + (1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()), |
| 38 | + (1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()), |
| 39 | + (1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()), |
| 40 | + (1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()), |
| 41 | + (2, 0, 0): lambda: ([op.stmts.Y()],), |
| 42 | + (2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]), |
| 43 | + (2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]), |
| 44 | + (2, 0, 3): lambda: (sdag(), [op.stmts.Y()]), |
| 45 | +} |
| 46 | + |
| 47 | + |
| 48 | +def equivalent_u3_para( |
| 49 | + theta_half_pi: int, phi_half_pi: int, lam_half_pi: int |
| 50 | +) -> tuple[int, int, int]: |
| 51 | + """ |
| 52 | + 1. Assume all three angles are in the range [0, 4]. |
| 53 | + 2. U3(theta, phi, lam) = -U3(2pi-theta, phi+pi, lam+pi). |
| 54 | + """ |
| 55 | + return ((4 - theta_half_pi) % 4, (phi_half_pi + 2) % 4, (lam_half_pi + 2) % 4) |
| 56 | + |
| 57 | + |
| 58 | +class SquinU3ToClifford(RewriteRule): |
| 59 | + """ |
| 60 | + Rewrite squin U3 statements to clifford when possible. |
| 61 | + """ |
| 62 | + |
| 63 | + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: |
| 64 | + if isinstance(node, (qubit.Apply, qubit.Broadcast)): |
| 65 | + return self.rewrite_ApplyOrBroadcast_onU3(node) |
| 66 | + else: |
| 67 | + return RewriteResult() |
| 68 | + |
| 69 | + def get_constant(self, node: ir.SSAValue) -> float | None: |
| 70 | + if isinstance(node.owner, py.Constant): |
| 71 | + # node.value is a PyAttr, need to get the wrapped value out |
| 72 | + return node.owner.value.unwrap() |
| 73 | + else: |
| 74 | + return None |
| 75 | + |
| 76 | + def resolve_angle(self, angle: float) -> int | None: |
| 77 | + """ |
| 78 | + Normalize the angle to be in the range [0, 2π). |
| 79 | + """ |
| 80 | + # convert to 0.0~1.0, in unit of pi/2 |
| 81 | + angle_half_pi = angle / math.pi * 2.0 |
| 82 | + |
| 83 | + mod = angle_half_pi % 1.0 |
| 84 | + if not (np.isclose(mod, 0.0) or np.isclose(mod, 1.0)): |
| 85 | + return None |
| 86 | + |
| 87 | + else: |
| 88 | + return round((angle / math.tau) % 1 * 4) % 4 |
| 89 | + |
| 90 | + def rewrite_ApplyOrBroadcast_onU3( |
| 91 | + self, node: qubit.Apply | qubit.Broadcast |
| 92 | + ) -> RewriteResult: |
| 93 | + """ |
| 94 | + Rewrite Apply and Broadcast nodes to their clifford equivalent statements. |
| 95 | + """ |
| 96 | + if not isinstance(node.operator.owner, op.stmts.U3): |
| 97 | + return RewriteResult() |
| 98 | + |
| 99 | + gates = self.decompose_U3_gates(node.operator.owner) |
| 100 | + |
| 101 | + if len(gates) == 0: |
| 102 | + return RewriteResult() |
| 103 | + |
| 104 | + for stmt_list in gates: |
| 105 | + for gate_stmt in stmt_list[:-1]: |
| 106 | + gate_stmt.insert_before(node) |
| 107 | + |
| 108 | + oper = stmt_list[-1] |
| 109 | + oper.insert_before(node) |
| 110 | + new_node = node.__class__(operator=oper.result, qubits=node.qubits) |
| 111 | + new_node.insert_before(node) |
| 112 | + |
| 113 | + node.delete() |
| 114 | + |
| 115 | + # rewrite U3 to clifford gates |
| 116 | + return RewriteResult(has_done_something=True) |
| 117 | + |
| 118 | + def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...]: |
| 119 | + """ |
| 120 | + Rewrite U3 statements to clifford gates if possible. |
| 121 | + """ |
| 122 | + theta = self.get_constant(node.theta) |
| 123 | + phi = self.get_constant(node.phi) |
| 124 | + lam = self.get_constant(node.lam) |
| 125 | + |
| 126 | + if theta is None or phi is None or lam is None: |
| 127 | + return () |
| 128 | + |
| 129 | + theta_half_pi: int | None = self.resolve_angle(theta) |
| 130 | + phi_half_pi: int | None = self.resolve_angle(phi) |
| 131 | + lam_half_pi: int | None = self.resolve_angle(lam) |
| 132 | + |
| 133 | + if theta_half_pi is None or phi_half_pi is None or lam_half_pi is None: |
| 134 | + return () |
| 135 | + |
| 136 | + angles_key = (theta_half_pi, phi_half_pi, lam_half_pi) |
| 137 | + if angles_key not in U3_HALF_PI_ANGLE_TO_GATES: |
| 138 | + angles_key = equivalent_u3_para(*angles_key) |
| 139 | + if angles_key not in U3_HALF_PI_ANGLE_TO_GATES: |
| 140 | + return () |
| 141 | + |
| 142 | + gates_stmts = U3_HALF_PI_ANGLE_TO_GATES.get(angles_key) |
| 143 | + |
| 144 | + # no consistent gates, then: |
| 145 | + assert ( |
| 146 | + gates_stmts is not None |
| 147 | + ), "internal error, U3 gates not found for angles: {}".format(angles_key) |
| 148 | + |
| 149 | + return gates_stmts() |
0 commit comments