diff --git a/src/bloqade/squin/rewrite/U3_to_clifford.py b/src/bloqade/squin/rewrite/U3_to_clifford.py index bc3c4f63b..e353594b8 100644 --- a/src/bloqade/squin/rewrite/U3_to_clifford.py +++ b/src/bloqade/squin/rewrite/U3_to_clifford.py @@ -14,6 +14,20 @@ def sdag() -> list[ir.Statement]: return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)] +def sqrt_x_dag() -> list[ir.Statement]: + return [ + _op := op.stmts.SqrtX(), + op.stmts.Adjoint(op=_op.result, is_unitary=True), + ] + + +def sqrt_y_dag() -> list[ir.Statement]: + return [ + _op := op.stmts.SqrtY(), + op.stmts.Adjoint(op=_op.result, is_unitary=True), + ] + + # (theta, phi, lam) U3_HALF_PI_ANGLE_TO_GATES: dict[ tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]] @@ -26,21 +40,21 @@ def sdag() -> list[ir.Statement]: (1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]), (1, 0, 2): lambda: ([op.stmts.H()],), (1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]), - (1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()), - (1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()), - (1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()), - (1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()), + (1, 1, 0): lambda: ([op.stmts.S()], sqrt_x_dag()), + (1, 1, 1): lambda: ([op.stmts.Z()], sqrt_x_dag()), + (1, 1, 2): lambda: (sdag(), sqrt_x_dag()), + (1, 1, 3): lambda: (sqrt_x_dag(),), + (1, 2, 0): lambda: ([op.stmts.Z()], sqrt_y_dag()), + (1, 2, 1): lambda: (sdag(), sqrt_y_dag()), + (1, 2, 2): lambda: (sqrt_y_dag(),), + (1, 2, 3): lambda: ([op.stmts.S()], sqrt_y_dag()), + (1, 3, 0): lambda: (sdag(), [op.stmts.SqrtX()]), + (1, 3, 1): lambda: ([op.stmts.SqrtX()],), + (1, 3, 2): lambda: ([op.stmts.S()], [op.stmts.SqrtX()]), + (1, 3, 3): lambda: ([op.stmts.Z()], [op.stmts.SqrtX()]), (2, 0, 0): lambda: ([op.stmts.Y()],), (2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]), - (2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]), + (2, 0, 2): lambda: ([op.stmts.X()],), (2, 0, 3): lambda: (sdag(), [op.stmts.Y()]), } diff --git a/test/squin/rewrite/test_U3_to_clifford.py b/test/squin/rewrite/test_U3_to_clifford.py index 6076c1cf1..06030120b 100644 --- a/test/squin/rewrite/test_U3_to_clifford.py +++ b/test/squin/rewrite/test_U3_to_clifford.py @@ -17,10 +17,6 @@ def unsafe_run(self, mt: ir.Method): ).rewrite(mt.code) -def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: - return method.callable_region.blocks[0].stmts.at(idx) - - def filter_statements_by_type(method: ir.Method, types: tuple[type, ...]) -> list[type]: return [ type(stmt) @@ -39,8 +35,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Identity) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Identity] + assert filtered_stmts == expected_stmts def test_s(): @@ -52,7 +49,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S] + assert filtered_stmts == expected_stmts # exercise equivalent_u3_para check ## assumes it's already in units of half pi and normalized to [0, 1) @@ -63,8 +62,9 @@ def test_equiv(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test_equiv.dialects)(test_equiv) - - assert isinstance(get_stmt_at_idx(test_equiv, 4), op.stmts.S) + filtered_stmts = filter_statements_by_type(test_equiv, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S] + assert filtered_stmts == expected_stmts def test_s_alternative(): @@ -76,7 +76,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S] + assert filtered_stmts == expected_stmts def test_z(): @@ -95,10 +97,9 @@ def test(): qubit.apply(op2, q[2]) SquinToCliffordTestPass(test.dialects)(test) - - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 8), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Z) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Z] * 3 + assert filtered_stmts == expected_stmts def test_z_alternative(): @@ -110,7 +111,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Z] + assert filtered_stmts == expected_stmts def test_sdag(): @@ -122,8 +125,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint] + assert filtered_stmts == expected_stmts @kernel def test_equiv(): @@ -132,8 +136,9 @@ def test_equiv(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test_equiv.dialects)(test_equiv) - assert isinstance(get_stmt_at_idx(test_equiv, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test_equiv, 5), op.stmts.Adjoint) + filtered_stmts = filter_statements_by_type(test_equiv, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint] + assert filtered_stmts == expected_stmts def test_sdag_alternative_negative(): @@ -145,8 +150,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint] + assert filtered_stmts == expected_stmts def test_sdag_alternative(): @@ -158,8 +164,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint] + assert filtered_stmts == expected_stmts def test_sdag_weird_case(): @@ -171,8 +178,9 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint] + assert filtered_stmts == expected_stmts def test_sdag_weirder_case(): @@ -184,8 +192,11 @@ def test(): qubit.apply(oper, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint] + # Technically a Y afterwards, just want to check the first two + # stmts are S + Adjoint + assert filtered_stmts[:-1] == expected_stmts def test_sqrt_y(): @@ -201,9 +212,9 @@ def test(): qubit.apply(op1, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 8), op.stmts.SqrtY) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.SqrtY] * 2 + assert filtered_stmts == expected_stmts def test_s_sqrt_y(): @@ -219,10 +230,9 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 10), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.SqrtY) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.SqrtY, op.stmts.S, op.stmts.SqrtY] + assert filtered_stmts == expected_stmts def test_h(): @@ -237,8 +247,9 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.H) - assert isinstance(get_stmt_at_idx(test, 8), op.stmts.H) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.H] * 2 + assert filtered_stmts == expected_stmts def test_sdg_sqrt_y(): @@ -254,15 +265,12 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) - assert isinstance(get_stmt_at_idx(test, 7), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 11), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Adjoint) - assert isinstance(get_stmt_at_idx(test, 14), op.stmts.SqrtY) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY] * 2 + assert filtered_stmts == expected_stmts -def test_sqrt_y_s(): +def test_s_sqrt_x_dag(): @kernel def test(): @@ -274,10 +282,12 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.SqrtX, op.stmts.Adjoint] * 2 + assert filtered_stmts == expected_stmts -def test_s_sqrt_y_s(): +def test_z_sqrt_x_dag(): @kernel def test(): @@ -290,19 +300,12 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (op.stmts.S, op.stmts.SqrtY)) - - assert relevant_stmts == [ - op.stmts.S, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.S, - op.stmts.SqrtY, - op.stmts.S, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Z, op.stmts.SqrtX, op.stmts.Adjoint] * 2 + assert filtered_stmts == expected_stmts -def test_z_sqrt_y_s(): +def test_s_dag_sqrt_x_dag(): @kernel def test(): @@ -314,21 +317,17 @@ def test(): qubit.apply(op1, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.Z, op.stmts.SqrtY, op.stmts.S) - ) - assert relevant_stmts == [ - op.stmts.Z, - op.stmts.SqrtY, + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [ op.stmts.S, - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.S, - ] + op.stmts.Adjoint, + op.stmts.SqrtX, + op.stmts.Adjoint, + ] * 2 + assert filtered_stmts == expected_stmts -def test_sdg_sqrt_y_s(): +def test_sqrt_x_dag(): @kernel def test(): @@ -341,24 +340,12 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.SqrtX, op.stmts.Adjoint] * 2 + assert filtered_stmts == expected_stmts - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY) - ) - - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.S, - ] - -def test_sqrt_y_z(): +def test_z_sqrt_y_dag(): @kernel def test(): @@ -371,14 +358,12 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 10), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Z) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Z, op.stmts.SqrtY, op.stmts.Adjoint] * 2 + assert filtered_stmts == expected_stmts -def test_s_sqrt_y_z(): +def test_s_dag_sqrt_y_dag(): @kernel def test(): @@ -391,22 +376,17 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.SqrtY, op.stmts.Z) - ) - - assert relevant_stmts == [ - op.stmts.S, - op.stmts.SqrtY, - op.stmts.Z, + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [ op.stmts.S, + op.stmts.Adjoint, op.stmts.SqrtY, - op.stmts.Z, - ] + op.stmts.Adjoint, + ] * 2 + assert filtered_stmts == expected_stmts -def test_z_sqrt_y_z(): +def test_sqrt_y_dag(): @kernel def test(): @@ -420,19 +400,12 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (op.stmts.Z, op.stmts.SqrtY)) - - assert relevant_stmts == [ - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.Z, - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.Z, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.SqrtY, op.stmts.Adjoint] * 2 + assert filtered_stmts == expected_stmts -def test_sdg_sqrt_y_z(): +def test_s_sqrt_y_dag(): @kernel def test(): @@ -445,88 +418,57 @@ def test(): qubit.apply(op1, q[1]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY, op.stmts.Z) - ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.Z, - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.Z, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.SqrtY, op.stmts.Adjoint] * 2 + assert filtered_stmts == expected_stmts -def test_sqrt_y_sdg(): +def test_s_dag_sqrt_x(): @kernel def test(): q = qubit.new(4) # (1, 3, 0) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.0 * math.tau) - qubit.apply(op0, q[0]) + u3 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.0 * math.tau) + qubit.apply(u3, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) - ) - assert relevant_stmts == [ - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtX] + assert filtered_stmts == expected_stmts -def test_s_sqrt_y_sdg(): +def test_sqrt_x(): @kernel def test(): q = qubit.new(4) # (1, 3, 1) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.25 * math.tau) - qubit.apply(op0, q[0]) + u3 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.25 * math.tau) + qubit.apply(u3, q[0]) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type( - test, (op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) - ) - - assert relevant_stmts == [ - op.stmts.S, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.SqrtX] + assert filtered_stmts == expected_stmts -def test_z_sqrt_y_sdg(): +def test_s_sqrt_x(): @kernel def test(): q = qubit.new(4) # (1, 3, 2) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.5 * math.tau) - qubit.apply(op0, q[0]) + u3 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.5 * math.tau) + qubit.apply(u3, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.Z, op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) - ) - assert relevant_stmts == [ - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.SqrtX] + assert filtered_stmts == expected_stmts -def test_sdg_sqrt_y_sdg(): +def test_z_sqrt_x(): @kernel def test(): @@ -536,17 +478,9 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY) - ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Z, op.stmts.SqrtX] + assert filtered_stmts == expected_stmts def test_y(): @@ -559,7 +493,9 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Y) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.Y] + assert filtered_stmts == expected_stmts def test_s_y(): @@ -572,11 +508,12 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Y) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Y] + assert filtered_stmts == expected_stmts -def test_z_y(): +def test_x(): @kernel def test(): @@ -586,8 +523,9 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Y) + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.X] + assert filtered_stmts == expected_stmts def test_sdg_y(): @@ -600,12 +538,6 @@ def test(): qubit.apply(op0, q[0]) SquinToCliffordTestPass(test.dialects)(test) - - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.Y) - ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.Y, - ] + filtered_stmts = filter_statements_by_type(test, (op.stmts.Operator,)) + expected_stmts = [op.stmts.S, op.stmts.Adjoint, op.stmts.Y] + assert filtered_stmts == expected_stmts diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim new file mode 100644 index 000000000..4764563d0 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/qubit/u3_gates.stim @@ -0,0 +1,5 @@ + +Z 0 +SQRT_X_DAG 0 +SQRT_X_DAG 0 +SQRT_X 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim index cb1d4a8cf..6eb044c65 100644 --- a/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim +++ b/test/stim/passes/stim_reference_programs/qubit/u3_to_clifford.stim @@ -1,12 +1,3 @@ H 0 -S 0 -SQRT_Y 0 -S 0 -S_DAG 0 -SQRT_Y 0 -S 0 -S 0 -SQRT_Y 0 -S_DAG 0 MZ(0.00000000) 0 diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index f91e3a0af..da2dd821a 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -138,19 +138,6 @@ def test(): q = qubit.new(n_qubits) # apply U3 rotation that can be translated to a Clifford gate squin.qubit.apply(op.u(0.25 * math.tau, 0.0 * math.tau, 0.5 * math.tau), q[0]) - # S @ SQRT_Y @ S = Z @ SQRT_X - squin.qubit.apply( - op.u(-0.25 * math.tau, -0.25 * math.tau, -0.25 * math.tau), q[0] - ) - # S @ SQRT_Y @ S_DAG = SQRT_X_DAG - squin.qubit.apply( - op.u(-0.25 * math.tau, -0.25 * math.tau, 0.25 * math.tau), q[0] - ) - # S_DAG @ SQRT_Y @ S = SQRT_X - squin.qubit.apply( - op.u(-0.25 * math.tau, 0.25 * math.tau, -0.25 * math.tau), q[0] - ) - # measure out squin.qubit.measure(q) return @@ -187,6 +174,50 @@ def test(): assert codegen(test).strip() == "SQRT_Y 0" +def test_adjoint_gates_rewrite(): + + @squin.kernel + def test(): + q = qubit.new(4) + s_adj = op.adjoint(op.s()) + qubit.apply(s_adj, q[0]) + sqrt_x_adj = op.adjoint(op.sqrt_x()) + qubit.apply(sqrt_x_adj, q[1]) + sqrt_y_adj = op.adjoint(op.sqrt_y()) + qubit.apply(sqrt_y_adj, q[2]) + sqrt_z_adj = op.adjoint(op.sqrt_z()) # same as S_DAG + qubit.apply(sqrt_z_adj, q[3]) + return + + SquinToStimPass(test.dialects)(test) + assert codegen(test).strip() == "S_DAG 0\nSQRT_X_DAG 1\nSQRT_Y_DAG 2\nS_DAG 3" + + +def test_u3_rewrite(): + + @squin.kernel + def test(): + q = qubit.new(1) + u3 = op.u( + -math.pi / 2, -math.pi / 2, -math.pi / 2 + ) # S @ SQRT_Y @ S = SQRT_X_DAG @ Z + qubit.apply(u3, q[0]) + u3 = op.u( + -math.pi / 2, -math.pi / 2, math.pi / 2 + ) # S @ SQRT_Y @ S_DAG = SQRT_X_DAG + qubit.apply(u3, q[0]) + u3 = op.u( + -math.pi / 2, math.pi / 2, -math.pi / 2 + ) # S_DAG @ SQRT_Y @ S = SQRT_X + qubit.apply(u3, q[0]) + + return + + SquinToStimPass(test.dialects)(test) + base_stim_prog = load_reference_program("u3_gates.stim") + assert codegen(test) == base_stim_prog.rstrip() + + def test_for_loop_nontrivial_index_rewrite(): @squin.kernel