Skip to content

Commit d515298

Browse files
authored
Rafaelha/adjoint gates squin to sim (#559)
1 parent 1e347d8 commit d515298

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

src/bloqade/stim/rewrite/qubit_to_stim.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,13 @@ def rewrite_SingleQubitGate(
8080
stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None)
8181
if stim_stmt_cls is None:
8282
return RewriteResult()
83-
stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas))
83+
84+
if isinstance(stmt, gate.stmts.SingleQubitNonHermitianGate):
85+
stim_stmt = stim_stmt_cls(
86+
targets=tuple(qubit_idx_ssas), dagger=stmt.adjoint
87+
)
88+
else:
89+
stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas))
8490
stmt.replace_by(stim_stmt)
8591

8692
return RewriteResult(has_done_something=True)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
2+
S 0
3+
SQRT_Y 0
4+
S 0
5+
S_DAG 0
6+
SQRT_Y 0
7+
S 0
8+
S 0
9+
SQRT_Y 0
10+
S_DAG 0

test/stim/passes/test_squin_qubit_to_stim.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import math
3+
from math import pi
34

45
from kirin import ir
56
from kirin.dialects import py
@@ -151,6 +152,37 @@ def test():
151152
assert codegen(test).strip() == "SQRT_Y 0"
152153

153154

155+
def test_adjoint_gates_rewrite():
156+
157+
@sq.kernel
158+
def test():
159+
q = sq.qalloc(4)
160+
sq.s_adj(q[0])
161+
sq.sqrt_x_adj(q[1])
162+
sq.sqrt_y_adj(q[2])
163+
sq.sqrt_z_adj(q[3]) # same as S_DAG
164+
return
165+
166+
SquinToStimPass(test.dialects)(test)
167+
assert codegen(test).strip() == "S_DAG 0\nSQRT_X_DAG 1\nSQRT_Y_DAG 2\nS_DAG 3"
168+
169+
170+
def test_u3_rewrite():
171+
172+
@sq.kernel
173+
def test():
174+
q = sq.qalloc(1)
175+
176+
sq.u3(-pi / 2, -pi / 2, -pi / 2, q[0]) # S @ SQRT_Y @ S = Z @ SQRT_X
177+
sq.u3(-pi / 2, -pi / 2, pi / 2, q[0]) # S @ SQRT_Y @ S_DAG = SQRT_X_DAG
178+
sq.u3(-pi / 2, pi / 2, -pi / 2, q[0]) # S_DAG @ SQRT_Y @ S = SQRT_X
179+
return
180+
181+
SquinToStimPass(test.dialects)(test)
182+
base_stim_prog = load_reference_program("u3_gates.stim")
183+
assert codegen(test) == base_stim_prog.rstrip()
184+
185+
154186
def test_for_loop_nontrivial_index_rewrite():
155187

156188
@sq.kernel

0 commit comments

Comments
 (0)