diff --git a/src/bloqade/squin/rewrite/U3_to_clifford.py b/src/bloqade/squin/rewrite/U3_to_clifford.py index 7e5ac1cc..50458233 100644 --- a/src/bloqade/squin/rewrite/U3_to_clifford.py +++ b/src/bloqade/squin/rewrite/U3_to_clifford.py @@ -15,6 +15,14 @@ class Sdag(ir.Statement): pass +class SqrtXdag(ir.Statement): + pass + + +class SqrtYdag(ir.Statement): + pass + + # (theta, phi, lam) U3_HALF_PI_ANGLE_TO_GATES: dict[ tuple[int, int, int], list[type[ir.Statement]] | list[None] @@ -27,21 +35,21 @@ class Sdag(ir.Statement): (1, 0, 1): [gate.stmts.S, gate.stmts.SqrtY], (1, 0, 2): [gate.stmts.H], (1, 0, 3): [Sdag, gate.stmts.SqrtY], - (1, 1, 0): [gate.stmts.SqrtY, gate.stmts.S], - (1, 1, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.S], - (1, 1, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S], - (1, 1, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.S], - (1, 2, 0): [gate.stmts.SqrtY, gate.stmts.Z], - (1, 2, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z], - (1, 2, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.Z], - (1, 2, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.Z], - (1, 3, 0): [gate.stmts.SqrtY, Sdag], - (1, 3, 1): [gate.stmts.S, gate.stmts.SqrtY, Sdag], - (1, 3, 2): [gate.stmts.Z, gate.stmts.SqrtY, Sdag], - (1, 3, 3): [Sdag, gate.stmts.SqrtY, Sdag], + (1, 1, 0): [gate.stmts.S, SqrtXdag], + (1, 1, 1): [gate.stmts.Z, SqrtXdag], + (1, 1, 2): [Sdag, SqrtXdag], + (1, 1, 3): [SqrtXdag], + (1, 2, 0): [gate.stmts.Z, SqrtYdag], + (1, 2, 1): [Sdag, SqrtYdag], + (1, 2, 2): [SqrtYdag], + (1, 2, 3): [gate.stmts.S, SqrtYdag], + (1, 3, 0): [Sdag, gate.stmts.SqrtX], + (1, 3, 1): [gate.stmts.SqrtX], + (1, 3, 2): [gate.stmts.S, gate.stmts.SqrtX], + (1, 3, 3): [gate.stmts.Z, gate.stmts.SqrtX], (2, 0, 0): [gate.stmts.Y], (2, 0, 1): [gate.stmts.S, gate.stmts.Y], - (2, 0, 2): [gate.stmts.Z, gate.stmts.Y], + (2, 0, 2): [gate.stmts.X], (2, 0, 3): [Sdag, gate.stmts.Y], } @@ -106,6 +114,10 @@ def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult: for gate_stmt in gates: if gate_stmt is Sdag: new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits) + elif gate_stmt is SqrtXdag: + new_stmt = gate.stmts.SqrtX(adjoint=True, qubits=node.qubits) + elif gate_stmt is SqrtYdag: + new_stmt = gate.stmts.SqrtY(adjoint=True, qubits=node.qubits) else: new_stmt = gate_stmt(qubits=node.qubits) new_stmt.insert_before(node) diff --git a/test/squin/rewrite/test_U3_to_clifford.py b/test/squin/rewrite/test_U3_to_clifford.py index 0aebab77..08a52bed 100644 --- a/test/squin/rewrite/test_U3_to_clifford.py +++ b/test/squin/rewrite/test_U3_to_clifford.py @@ -231,7 +231,7 @@ def test(): assert not sqrt_y_stmt.adjoint -def test_sqrt_y_s(): +def test_s_sqrt_x_dag(): @sq.kernel def test(): @@ -246,22 +246,22 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) test.print() - assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.S) - assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtX) + assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtX) - sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + sqrt_x_stmts = filter_statements_by_type(test, (gate.stmts.SqrtX,)) s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) - for sqrt_y_stmt in sqrt_y_stmts: - assert not sqrt_y_stmt.adjoint + for sqrt_x_stmt in sqrt_x_stmts: + assert sqrt_x_stmt.adjoint for s_stmt in s_stmts: assert not s_stmt.adjoint -def test_s_sqrt_y_s(): +def test_z_sqrt_x_dag(): @sq.kernel def test(): @@ -277,29 +277,28 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) - sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + sqrt_x_stmts = filter_statements_by_type(test, (gate.stmts.SqrtX,)) - # Should be S, SqrtY, S for each op assert [ type(stmt) - for stmt in filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) + for stmt in filter_statements_by_type( + test, (gate.stmts.S, gate.stmts.Z, gate.stmts.SqrtX, gate.stmts.SqrtY) + ) ] == [ - gate.stmts.S, - gate.stmts.SqrtY, - gate.stmts.S, - gate.stmts.S, - gate.stmts.SqrtY, - gate.stmts.S, + gate.stmts.Z, + gate.stmts.SqrtX, + gate.stmts.Z, + gate.stmts.SqrtX, ] # Check adjoint property for s_stmt in s_stmts: assert not s_stmt.adjoint - for sqrt_y_stmt in sqrt_y_stmts: - assert not sqrt_y_stmt.adjoint + for sqrt_x_stmt in sqrt_x_stmts: + assert sqrt_x_stmt.adjoint -def test_z_sqrt_y_s(): +def test_s_dag_sqrt_x_dag(): @sq.kernel def test(): @@ -316,25 +315,23 @@ def test(): test.print() relevant_stmts = filter_statements_by_type( - test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S) + test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S, gate.stmts.SqrtX) ) expected_types = [ - gate.stmts.Z, - gate.stmts.SqrtY, gate.stmts.S, - gate.stmts.Z, - gate.stmts.SqrtY, + gate.stmts.SqrtX, gate.stmts.S, + gate.stmts.SqrtX, ] assert [type(stmt) for stmt in relevant_stmts] == expected_types for relevant_stmt in relevant_stmts: if type(relevant_stmt) is not gate.stmts.Z: - assert not relevant_stmt.adjoint + assert relevant_stmt.adjoint -def test_sdg_sqrt_y_s(): +def test_sqrt_x_dag(): @sq.kernel def test(): @@ -349,29 +346,24 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) + relevant_stmts = filter_statements_by_type( + test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.SqrtX) + ) - # Should be Sdg, SqrtY, S for each op assert [type(stmt) for stmt in relevant_stmts] == [ - gate.stmts.S, - gate.stmts.SqrtY, - gate.stmts.S, - gate.stmts.S, - gate.stmts.SqrtY, - gate.stmts.S, + gate.stmts.SqrtX, + gate.stmts.SqrtX, ] - # Check adjoint property: the first S in each group should be adjoint s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) - sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + sqrt_x_stmts = filter_statements_by_type(test, (gate.stmts.SqrtX,)) - assert s_stmts[0].adjoint - assert s_stmts[2].adjoint - for sqrt_y_stmt in sqrt_y_stmts: - assert not sqrt_y_stmt.adjoint + assert not s_stmts + for sqrt_x_stmt in sqrt_x_stmts: + assert sqrt_x_stmt.adjoint -def test_sqrt_y_z(): +def test_z_sqrt_y_dag(): @sq.kernel def test(): @@ -385,17 +377,17 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) test.print() - assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtY) sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) for sqrt_y_stmt in sqrt_y_stmts: - assert not sqrt_y_stmt.adjoint + assert sqrt_y_stmt.adjoint -def test_s_sqrt_y_z(): +def test_s_dag_sqrt_y_dag(): @sq.kernel def test(): @@ -417,18 +409,16 @@ def test(): assert [type(stmt) for stmt in relevant_stmts] == [ gate.stmts.S, gate.stmts.SqrtY, - gate.stmts.Z, gate.stmts.S, gate.stmts.SqrtY, - gate.stmts.Z, ] for stmt in relevant_stmts: if type(stmt) is not gate.stmts.Z: - assert not stmt.adjoint + assert stmt.adjoint -def test_z_sqrt_y_z(): +def test_sqrt_y_dag(): @sq.kernel def test(): @@ -444,21 +434,17 @@ def test(): relevant_stmts = filter_statements_by_type(test, (gate.stmts.Z, gate.stmts.SqrtY)) expected_types = [ - gate.stmts.Z, gate.stmts.SqrtY, - gate.stmts.Z, - gate.stmts.Z, gate.stmts.SqrtY, - gate.stmts.Z, ] assert [type(stmt) for stmt in relevant_stmts] == expected_types for stmt in relevant_stmts: if type(stmt) is gate.stmts.SqrtY: - assert not stmt.adjoint + assert stmt.adjoint -def test_sdg_sqrt_y_z(): +def test_s_sqrt_y_dag(): @sq.kernel def test(): @@ -477,28 +463,24 @@ def test(): test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z) ) - # Should be Sdag, SqrtY, Z for each op assert [type(stmt) for stmt in relevant_stmts] == [ gate.stmts.S, gate.stmts.SqrtY, - gate.stmts.Z, gate.stmts.S, gate.stmts.SqrtY, - gate.stmts.Z, ] - # Check adjoint property: Sdag should be adjoint, SqrtY and Z should not s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) for s_stmt in s_stmts: - assert s_stmt.adjoint + assert not s_stmt.adjoint for sqrt_y_stmt in sqrt_y_stmts: - assert not sqrt_y_stmt.adjoint + assert sqrt_y_stmt.adjoint -def test_sqrt_y_sdg(): +def test_s_dag_sqrt_x(): @sq.kernel def test(): @@ -510,18 +492,19 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY, gate.stmts.S)) - # Check for SqrtY followed by S (adjoint property can be checked if needed) + relevant_stmts = filter_statements_by_type( + test, (gate.stmts.SqrtY, gate.stmts.SqrtX, gate.stmts.S) + ) assert [type(stmt) for stmt in relevant_stmts] == [ - gate.stmts.SqrtY, gate.stmts.S, + gate.stmts.SqrtX, ] - assert not relevant_stmts[0].adjoint - assert relevant_stmts[1].adjoint + assert relevant_stmts[0].adjoint + assert not relevant_stmts[1].adjoint -def test_s_sqrt_y_sdg(): +def test_sqrt_x(): @sq.kernel def test(): @@ -532,20 +515,17 @@ def test(): ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) + relevant_stmts = filter_statements_by_type( + test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.SqrtX) + ) assert [type(stmt) for stmt in relevant_stmts] == [ - gate.stmts.S, - gate.stmts.SqrtY, - gate.stmts.S, + gate.stmts.SqrtX, ] - # The last S should be adjoint assert not relevant_stmts[0].adjoint - assert not relevant_stmts[1].adjoint - assert relevant_stmts[2].adjoint -def test_z_sqrt_y_sdg(): +def test_s_sqrt_x(): @sq.kernel def test(): @@ -558,19 +538,17 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) relevant_stmts = filter_statements_by_type( - test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S) + test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S, gate.stmts.SqrtX) ) - # Should be Z, SqrtY, S (adjoint) assert [type(stmt) for stmt in relevant_stmts] == [ - gate.stmts.Z, - gate.stmts.SqrtY, gate.stmts.S, + gate.stmts.SqrtX, ] + assert not relevant_stmts[0].adjoint assert not relevant_stmts[1].adjoint - assert relevant_stmts[2].adjoint -def test_sdg_sqrt_y_sdg(): +def test_z_sqrt_x(): @sq.kernel def test(): @@ -582,18 +560,15 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) + relevant_stmts = filter_statements_by_type( + test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z, gate.stmts.SqrtX) + ) - # Should be Sdag, SqrtY, Sdag for the op assert [type(stmt) for stmt in relevant_stmts] == [ - gate.stmts.S, - gate.stmts.SqrtY, - gate.stmts.S, + gate.stmts.Z, + gate.stmts.SqrtX, ] - # The first and last S should be adjoint, SqrtY should not - assert relevant_stmts[0].adjoint assert not relevant_stmts[1].adjoint - assert relevant_stmts[2].adjoint def test_y(): @@ -625,7 +600,7 @@ def test(): assert not s_stmt.adjoint -def test_z_y(): +def test_x(): @sq.kernel def test(): @@ -634,8 +609,7 @@ def test(): sq.u3(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Y) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.X) def test_sdg_y(): diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim index 691ee7fc..4764563d 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim @@ -1,10 +1,5 @@ -S 0 -SQRT_Y 0 -S 0 -S_DAG 0 -SQRT_Y 0 -S 0 -S 0 -SQRT_Y 0 -S_DAG 0 +Z 0 +SQRT_X_DAG 0 +SQRT_X_DAG 0 +SQRT_X 0 diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index eafd4c3e..84a90928 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -173,7 +173,7 @@ def test_u3_rewrite(): def test(): q = sq.qalloc(1) - sq.u3(-pi / 2, -pi / 2, -pi / 2, q[0]) # S @ SQRT_Y @ S = Z @ SQRT_X + sq.u3(-pi / 2, -pi / 2, -pi / 2, q[0]) # S @ SQRT_Y @ S = SQRT_X_DAG @ Z sq.u3(-pi / 2, -pi / 2, pi / 2, q[0]) # S @ SQRT_Y @ S_DAG = SQRT_X_DAG sq.u3(-pi / 2, pi / 2, -pi / 2, q[0]) # S_DAG @ SQRT_Y @ S = SQRT_X return