Skip to content

Commit 170088a

Browse files
authored
Fix missing adjoint problem (#562)
1 parent 4cef564 commit 170088a

File tree

5 files changed

+62
-13
lines changed

5 files changed

+62
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ requires-python = ">=3.10"
1313
dependencies = [
1414
"numpy>=1.22.0",
1515
"scipy>=1.13.1",
16-
"kirin-toolchain~=0.17.23",
16+
"kirin-toolchain~=0.17.26",
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",

src/bloqade/stim/rewrite/qubit_to_stim.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from bloqade.squin import op, noise, qubit
55
from bloqade.squin.rewrite import AddressAttribute
6-
from bloqade.stim.dialects import gate
76
from bloqade.stim.rewrite.util import (
87
SQUIN_STIM_OP_MAPPING,
98
rewrite_Control,
@@ -40,20 +39,24 @@ def rewrite_Apply_and_Broadcast(
4039

4140
assert isinstance(applied_op, op.stmts.Operator)
4241

42+
# Handle controlled gates with a separate procedure
4343
if isinstance(applied_op, op.stmts.Control):
4444
return rewrite_Control(stmt)
4545

46-
# need to handle Control through separate means
47-
4846
# check if its adjoint, assume its canonicalized so no nested adjoints.
4947
is_conj = False
5048
if isinstance(applied_op, op.stmts.Adjoint):
51-
if not applied_op.is_unitary:
49+
# By default the Adjoint has is_unitary = False, so we need to check
50+
# the inner applied operator to make sure its not just unitary,
51+
# but something that has an equivalent stim representation with *_DAG format.
52+
if isinstance(
53+
applied_op.op.owner, (op.stmts.SqrtX, op.stmts.SqrtY, op.stmts.S)
54+
):
55+
is_conj = True
56+
applied_op = applied_op.op.owner
57+
else:
5258
return RewriteResult()
5359

54-
is_conj = True
55-
applied_op = applied_op.op.owner
56-
5760
stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
5861
if stim_1q_op is None:
5962
return RewriteResult()
@@ -71,13 +74,14 @@ def rewrite_Apply_and_Broadcast(
7174
if qubit_idx_ssas is None:
7275
return RewriteResult()
7376

74-
if isinstance(stim_1q_op, gate.stmts.Gate):
77+
# At this point, we know for certain stim_1q_op must be SQRT_X, SQRT_Y, or S
78+
# and has the option to set the dagger attribute. If is_conj is false,
79+
# the rewrite would have terminated early so we know anything else has to be
80+
# a non 1Q gate operation.
81+
if is_conj:
7582
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas), dagger=is_conj)
7683
else:
7784
stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas))
7885
stmt.replace_by(stim_1q_stmt)
7986

8087
return RewriteResult(has_done_something=True)
81-
82-
83-
# put rewrites for measure statements in separate rule, then just have to dispatch
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
SQRT_X_DAG 0
3+
SQRT_Y_DAG 0
4+
S_DAG 0
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
11

22
H 0
3+
S 0
4+
SQRT_Y 0
5+
S 0
6+
S_DAG 0
7+
SQRT_Y 0
8+
S 0
9+
S 0
10+
SQRT_Y 0
11+
S_DAG 0
312
MZ(0.00000000) 0

test/stim/passes/test_squin_qubit_to_stim.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,26 @@ def test():
110110
assert codegen(test) == base_stim_prog.rstrip()
111111

112112

113+
def test_adjoint_rewrite():
114+
115+
@squin.kernel
116+
def test():
117+
q = qubit.new(1)
118+
sqrt_x_dag = op.adjoint(op.sqrt_x())
119+
sqrt_y_dag = op.adjoint(op.sqrt_y())
120+
sqrt_s_dag = op.adjoint(op.s())
121+
qubit.apply(sqrt_x_dag, q[0])
122+
qubit.apply(sqrt_y_dag, q[0])
123+
qubit.apply(sqrt_s_dag, q[0])
124+
return
125+
126+
SquinToStimPass(test.dialects)(test)
127+
128+
base_stim_prog = load_reference_program("adjoint_rewrite.stim")
129+
130+
assert codegen(test) == base_stim_prog.rstrip()
131+
132+
113133
def test_u3_to_clifford():
114134

115135
@kernel
@@ -118,12 +138,24 @@ def test():
118138
q = qubit.new(n_qubits)
119139
# apply U3 rotation that can be translated to a Clifford gate
120140
squin.qubit.apply(op.u(0.25 * math.tau, 0.0 * math.tau, 0.5 * math.tau), q[0])
141+
# S @ SQRT_Y @ S = Z @ SQRT_X
142+
squin.qubit.apply(
143+
op.u(-0.25 * math.tau, -0.25 * math.tau, -0.25 * math.tau), q[0]
144+
)
145+
# S @ SQRT_Y @ S_DAG = SQRT_X_DAG
146+
squin.qubit.apply(
147+
op.u(-0.25 * math.tau, -0.25 * math.tau, 0.25 * math.tau), q[0]
148+
)
149+
# S_DAG @ SQRT_Y @ S = SQRT_X
150+
squin.qubit.apply(
151+
op.u(-0.25 * math.tau, 0.25 * math.tau, -0.25 * math.tau), q[0]
152+
)
153+
121154
# measure out
122155
squin.qubit.measure(q)
123156
return
124157

125158
SquinToStimPass(test.dialects)(test)
126-
127159
base_stim_prog = load_reference_program("u3_to_clifford.stim")
128160

129161
assert codegen(test) == base_stim_prog.rstrip()

0 commit comments

Comments
 (0)