From 8cade008e76201148b28f0c61cdab998e64b5216 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 5 Nov 2025 15:19:40 +0100 Subject: [PATCH] Revert "Backport cleaner squin gate generation (#598)" This reverts commit a6626eba19e3f6adaf7b006e00cf7b0fceee895f. --- src/bloqade/squin/rewrite/U3_to_clifford.py | 40 +-- test/squin/rewrite/test_U3_to_clifford.py | 316 +++++++++++------- .../qubit/u3_gates.stim | 5 - .../qubit/u3_to_clifford.stim | 9 + test/stim/passes/test_squin_qubit_to_stim.py | 57 +--- 5 files changed, 227 insertions(+), 200 deletions(-) delete mode 100644 test/stim/passes/stim_reference_programs/qubit/u3_gates.stim diff --git a/src/bloqade/squin/rewrite/U3_to_clifford.py b/src/bloqade/squin/rewrite/U3_to_clifford.py index e353594b8..bc3c4f63b 100644 --- a/src/bloqade/squin/rewrite/U3_to_clifford.py +++ b/src/bloqade/squin/rewrite/U3_to_clifford.py @@ -14,20 +14,6 @@ def sdag() -> list[ir.Statement]: return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)] -def sqrt_x_dag() -> list[ir.Statement]: - return [ - _op := op.stmts.SqrtX(), - op.stmts.Adjoint(op=_op.result, is_unitary=True), - ] - - -def sqrt_y_dag() -> list[ir.Statement]: - return [ - _op := op.stmts.SqrtY(), - op.stmts.Adjoint(op=_op.result, is_unitary=True), - ] - - # (theta, phi, lam) U3_HALF_PI_ANGLE_TO_GATES: dict[ tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]] @@ -40,21 +26,21 @@ def sqrt_y_dag() -> list[ir.Statement]: (1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]), (1, 0, 2): lambda: ([op.stmts.H()],), (1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]), - (1, 1, 0): lambda: ([op.stmts.S()], sqrt_x_dag()), - (1, 1, 1): lambda: ([op.stmts.Z()], sqrt_x_dag()), - (1, 1, 2): lambda: (sdag(), sqrt_x_dag()), - (1, 1, 3): lambda: (sqrt_x_dag(),), - (1, 2, 0): lambda: ([op.stmts.Z()], sqrt_y_dag()), - (1, 2, 1): lambda: (sdag(), sqrt_y_dag()), - (1, 2, 2): lambda: (sqrt_y_dag(),), - (1, 2, 3): lambda: ([op.stmts.S()], sqrt_y_dag()), - (1, 3, 0): lambda: (sdag(), [op.stmts.SqrtX()]), - (1, 3, 1): lambda: ([op.stmts.SqrtX()],), - (1, 3, 2): lambda: ([op.stmts.S()], [op.stmts.SqrtX()]), - (1, 3, 3): lambda: ([op.stmts.Z()], [op.stmts.SqrtX()]), + (1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]), + (1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]), + (1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]), + (1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]), + (1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()), + (1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()), + (1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()), + (1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()), (2, 0, 0): lambda: ([op.stmts.Y()],), (2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]), - (2, 0, 2): lambda: ([op.stmts.X()],), + (2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]), (2, 0, 3): lambda: (sdag(), [op.stmts.Y()]), } diff --git a/test/squin/rewrite/test_U3_to_clifford.py b/test/squin/rewrite/test_U3_to_clifford.py index 06030120b..6076c1cf1 100644 --- a/test/squin/rewrite/test_U3_to_clifford.py +++ b/test/squin/rewrite/test_U3_to_clifford.py @@ -17,6 +17,10 @@ def unsafe_run(self, mt: ir.Method): ).rewrite(mt.code) +def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: + return method.callable_region.blocks[0].stmts.at(idx) + + def filter_statements_by_type(method: ir.Method, types: tuple[type, ...]) -> list[type]: return [ type(stmt) @@ -35,9 +39,8 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Identity] - assert filtered_stmts == expected_stmts + + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Identity) def test_s(): @@ -49,9 +52,7 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) # exercise equivalent_u3_para check ## assumes it's already in units of half pi and normalized to [0, 1) @@ -62,9 +63,8 @@ def test_equiv(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test_equiv.dialects)(test_equiv) - filtered_stmts = filter_statements_by_type(test_equiv, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S] - assert filtered_stmts == expected_stmts + + assert isinstance(get_stmt_at_idx(test_equiv, 4), op.stmts.S) def test_s_alternative(): @@ -76,9 +76,7 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) def test_z(): @@ -97,9 +95,10 @@ def test(): qubit.apply(op2, q[2]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Z] * 3 - assert filtered_stmts == expected_stmts + + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 8), op.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Z) def test_z_alternative(): @@ -111,9 +110,7 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Z] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) def test_sdag(): @@ -125,9 +122,8 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) @kernel def test_equiv(): @@ -136,9 +132,8 @@ def test_equiv(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test_equiv.dialects)(test_equiv) - filtered_stmts = filter_statements_by_type(test_equiv, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test_equiv, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test_equiv, 5), op.stmts.Adjoint) def test_sdag_alternative_negative(): @@ -150,9 +145,8 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) def test_sdag_alternative(): @@ -164,9 +158,8 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) def test_sdag_weird_case(): @@ -178,9 +171,8 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) def test_sdag_weirder_case(): @@ -192,11 +184,8 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint] - # Technically a Y afterwards, just want to check the first two - # stmts are S + Adjoint - assert filtered_stmts[:-1] == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) def test_sqrt_y(): @@ -212,9 +201,9 @@ def test(): qubit.apply(op1, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.SqrtY] * 2 - assert filtered_stmts == expected_stmts + + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 8), op.stmts.SqrtY) def test_s_sqrt_y(): @@ -230,9 +219,10 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.SqrtY, op.stmts.S, op.stmts.SqrtY] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 6), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 10), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 12), op.stmts.SqrtY) def test_h(): @@ -247,9 +237,8 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.H] * 2 - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.H) + assert isinstance(get_stmt_at_idx(test, 8), op.stmts.H) def test_sdg_sqrt_y(): @@ -265,12 +254,15 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY] * 2 - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + assert isinstance(get_stmt_at_idx(test, 7), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 11), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Adjoint) + assert isinstance(get_stmt_at_idx(test, 14), op.stmts.SqrtY) -def test_s_sqrt_x_dag(): +def test_sqrt_y_s(): @kernel def test(): @@ -282,12 +274,10 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.SqrtX, op.stmts.Adjoint] * 2 - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) -def test_z_sqrt_x_dag(): +def test_s_sqrt_y_s(): @kernel def test(): @@ -300,12 +290,19 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Z, op.stmts.SqrtX, op.stmts.Adjoint] * 2 - assert filtered_stmts == expected_stmts + relevant_stmts = filter_statements_by_type(test, (op.stmts.S, op.stmts.SqrtY)) + + assert relevant_stmts == [ + op.stmts.S, + op.stmts.SqrtY, + op.stmts.S, + op.stmts.S, + op.stmts.SqrtY, + op.stmts.S, + ] -def test_s_dag_sqrt_x_dag(): +def test_z_sqrt_y_s(): @kernel def test(): @@ -317,17 +314,21 @@ def test(): qubit.apply(op1, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [ + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.Z, op.stmts.SqrtY, op.stmts.S) + ) + assert relevant_stmts == [ + op.stmts.Z, + op.stmts.SqrtY, op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtX, - op.stmts.Adjoint, - ] * 2 - assert filtered_stmts == expected_stmts + op.stmts.Z, + op.stmts.SqrtY, + op.stmts.S, + ] -def test_sqrt_x_dag(): +def test_sdg_sqrt_y_s(): @kernel def test(): @@ -340,12 +341,24 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.SqrtX, op.stmts.Adjoint] * 2 - assert filtered_stmts == expected_stmts + relevant_stmts = filter_statements_by_type( + test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY) + ) + + assert relevant_stmts == [ + op.stmts.S, + op.stmts.Adjoint, + op.stmts.SqrtY, + op.stmts.S, + op.stmts.S, + op.stmts.Adjoint, + op.stmts.SqrtY, + op.stmts.S, + ] -def test_z_sqrt_y_dag(): + +def test_sqrt_y_z(): @kernel def test(): @@ -358,12 +371,14 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Z, op.stmts.SqrtY, op.stmts.Adjoint] * 2 - assert filtered_stmts == expected_stmts + + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 10), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Z) -def test_s_dag_sqrt_y_dag(): +def test_s_sqrt_y_z(): @kernel def test(): @@ -376,17 +391,22 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [ + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.S, op.stmts.SqrtY, op.stmts.Z) + ) + + assert relevant_stmts == [ op.stmts.S, - op.stmts.Adjoint, op.stmts.SqrtY, - op.stmts.Adjoint, - ] * 2 - assert filtered_stmts == expected_stmts + op.stmts.Z, + op.stmts.S, + op.stmts.SqrtY, + op.stmts.Z, + ] -def test_sqrt_y_dag(): +def test_z_sqrt_y_z(): @kernel def test(): @@ -400,12 +420,19 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.SqrtY, op.stmts.Adjoint] * 2 - assert filtered_stmts == expected_stmts + relevant_stmts = filter_statements_by_type(test, (op.stmts.Z, op.stmts.SqrtY)) + + assert relevant_stmts == [ + op.stmts.Z, + op.stmts.SqrtY, + op.stmts.Z, + op.stmts.Z, + op.stmts.SqrtY, + op.stmts.Z, + ] -def test_s_sqrt_y_dag(): +def test_sdg_sqrt_y_z(): @kernel def test(): @@ -418,57 +445,88 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.SqrtY, op.stmts.Adjoint] * 2 - assert filtered_stmts == expected_stmts + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY, op.stmts.Z) + ) + assert relevant_stmts == [ + op.stmts.S, + op.stmts.Adjoint, + op.stmts.SqrtY, + op.stmts.Z, + op.stmts.S, + op.stmts.Adjoint, + op.stmts.SqrtY, + op.stmts.Z, + ] -def test_s_dag_sqrt_x(): +def test_sqrt_y_sdg(): @kernel def test(): q = qubit.new(4) # (1, 3, 0) - u3 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.0 * math.tau) - qubit.apply(u3, q[0]) + op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.0 * math.tau) + qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtX] - assert filtered_stmts == expected_stmts + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) + ) + assert relevant_stmts == [ + op.stmts.SqrtY, + op.stmts.S, + op.stmts.Adjoint, + ] -def test_sqrt_x(): +def test_s_sqrt_y_sdg(): @kernel def test(): q = qubit.new(4) # (1, 3, 1) - u3 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.25 * math.tau) - qubit.apply(u3, q[0]) + op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.25 * math.tau) + qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.SqrtX] - assert filtered_stmts == expected_stmts + relevant_stmts = filter_statements_by_type( + test, (op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) + ) + + assert relevant_stmts == [ + op.stmts.S, + op.stmts.SqrtY, + op.stmts.S, + op.stmts.Adjoint, + ] -def test_s_sqrt_x(): +def test_z_sqrt_y_sdg(): @kernel def test(): q = qubit.new(4) # (1, 3, 2) - u3 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.5 * math.tau) - qubit.apply(u3, q[0]) + op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.5 * math.tau) + qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.SqrtX] - assert filtered_stmts == expected_stmts + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.Z, op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) + ) + assert relevant_stmts == [ + op.stmts.Z, + op.stmts.SqrtY, + op.stmts.S, + op.stmts.Adjoint, + ] -def test_z_sqrt_x(): +def test_sdg_sqrt_y_sdg(): @kernel def test(): @@ -478,9 +536,17 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Z, op.stmts.SqrtX] - assert filtered_stmts == expected_stmts + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY) + ) + assert relevant_stmts == [ + op.stmts.S, + op.stmts.Adjoint, + op.stmts.SqrtY, + op.stmts.S, + op.stmts.Adjoint, + ] def test_y(): @@ -493,9 +559,7 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.Y] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Y) def test_s_y(): @@ -508,12 +572,11 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Y] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Y) -def test_x(): +def test_z_y(): @kernel def test(): @@ -523,9 +586,8 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.X] - assert filtered_stmts == expected_stmts + assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Y) def test_sdg_y(): @@ -538,6 +600,12 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) - expected_stmts = [op.stmts.S, op.stmts.Adjoint, op.stmts.Y] - assert filtered_stmts == expected_stmts + + relevant_stmts = filter_statements_by_type( + test, (op.stmts.S, op.stmts.Adjoint, op.stmts.Y) + ) + assert relevant_stmts == [ + op.stmts.S, + op.stmts.Adjoint, + op.stmts.Y, + ] diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim deleted file mode 100644 index 4764563d0..000000000 --- a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim +++ /dev/null @@ -1,5 +0,0 @@ - -Z 0 -SQRT_X_DAG 0 -SQRT_X_DAG 0 -SQRT_X 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim index 6eb044c65..cb1d4a8cf 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim @@ -1,3 +1,12 @@ H 0 +S 0 +SQRT_Y 0 +S 0 +S_DAG 0 +SQRT_Y 0 +S 0 +S 0 +SQRT_Y 0 +S_DAG 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index da2dd821a..f91e3a0af 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -138,6 +138,19 @@ def test(): q = qubit.new(n_qubits) # apply U3 rotation that can be translated to a Clifford gate squin.qubit.apply(op.u(0.25 * math.tau, 0.0 * math.tau, 0.5 * math.tau), q[0]) + # S @ SQRT_Y @ S = Z @ SQRT_X + squin.qubit.apply( + op.u(-0.25 * math.tau, -0.25 * math.tau, -0.25 * math.tau), q[0] + ) + # S @ SQRT_Y @ S_DAG = SQRT_X_DAG + squin.qubit.apply( + op.u(-0.25 * math.tau, -0.25 * math.tau, 0.25 * math.tau), q[0] + ) + # S_DAG @ SQRT_Y @ S = SQRT_X + squin.qubit.apply( + op.u(-0.25 * math.tau, 0.25 * math.tau, -0.25 * math.tau), q[0] + ) + # measure out squin.qubit.measure(q) return @@ -174,50 +187,6 @@ def test(): assert codegen(test).strip() == "SQRT_Y 0" -def test_adjoint_gates_rewrite(): - - @squin.kernel - def test(): - q = qubit.new(4) - s_adj = op.adjoint(op.s()) - qubit.apply(s_adj, q[0]) - sqrt_x_adj = op.adjoint(op.sqrt_x()) - qubit.apply(sqrt_x_adj, q[1]) - sqrt_y_adj = op.adjoint(op.sqrt_y()) - qubit.apply(sqrt_y_adj, q[2]) - sqrt_z_adj = op.adjoint(op.sqrt_z()) # same as S_DAG - qubit.apply(sqrt_z_adj, q[3]) - return - - SquinToStimPass(test.dialects)(test) - assert codegen(test).strip() == "S_DAG 0\nSQRT_X_DAG 1\nSQRT_Y_DAG 2\nS_DAG 3" - - -def test_u3_rewrite(): - - @squin.kernel - def test(): - q = qubit.new(1) - u3 = op.u( - -math.pi / 2, -math.pi / 2, -math.pi / 2 - ) # S @ SQRT_Y @ S = SQRT_X_DAG @ Z - qubit.apply(u3, q[0]) - u3 = op.u( - -math.pi / 2, -math.pi / 2, math.pi / 2 - ) # S @ SQRT_Y @ S_DAG = SQRT_X_DAG - qubit.apply(u3, q[0]) - u3 = op.u( - -math.pi / 2, math.pi / 2, -math.pi / 2 - ) # S_DAG @ SQRT_Y @ S = SQRT_X - qubit.apply(u3, q[0]) - - return - - SquinToStimPass(test.dialects)(test) - base_stim_prog = load_reference_program("u3_gates.stim") - assert codegen(test) == base_stim_prog.rstrip() - - def test_for_loop_nontrivial_index_rewrite(): @squin.kernel