@@ -102,9 +102,46 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
102102 return stim .gate .H
103103 case op .stmts .S ():
104104 return stim .gate .S
105+ case op .stmts .Identity (): # enforce sites defined = num wires in
106+ return stim .gate .Identity
105107 case _:
106108 return None
107109
110+ # get the qubit indices from the Apply statement argument
111+ # wires/qubits
112+ def insert_qubit_idx_ssa (
113+ self , apply_stmt : wire .Apply | qubit .Apply
114+ ) -> tuple [ir .SSAValue , ...]:
115+
116+ if isinstance (apply_stmt , qubit .Apply ):
117+ qubits = apply_stmt .qubits
118+ address_attribute : AddressAttribute = self .get_address (qubits )
119+ # Should get an AddressTuple out of the address stored in attribute
120+ address_tuple = address_attribute .address
121+ qubit_idx_ssas : list [ir .SSAValue ] = []
122+ for address_qubit in address_tuple .data :
123+ qubit_idx = address_qubit .data
124+ qubit_idx_stmt = py .Constant (qubit_idx )
125+ qubit_idx_stmt .insert_before (apply_stmt )
126+ qubit_idx_ssas .append (qubit_idx_stmt .result )
127+
128+ return tuple (qubit_idx_ssas )
129+
130+ elif isinstance (apply_stmt , wire .Apply ):
131+ wire_ssas = apply_stmt .inputs
132+ qubit_idx_ssas : list [ir .SSAValue ] = []
133+ for wire_ssa in wire_ssas :
134+ address_attribute = self .get_address (wire_ssa )
135+ # get parent qubit idx
136+ wire_address = address_attribute .address
137+ qubit_idx = wire_address .origin_qubit .data
138+ qubit_idx_stmt = py .Constant (qubit_idx )
139+ # accumulate all qubit idx SSA to instantiate stim gate stmt
140+ qubit_idx_ssas .append (qubit_idx_stmt .result )
141+ qubit_idx_stmt .insert_before (apply_stmt )
142+
143+ return tuple (qubit_idx_ssas )
144+
108145 # might be worth attempting multiple dispatch like qasm2 rewrites
109146 # for Glob and Parallel to UOp
110147 # The problem is I'd have to introduce names for all the statements
@@ -142,51 +179,60 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
142179 # this is an SSAValue, need it to be the actual operator
143180 applied_op = apply_stmt .operator .owner
144181
145- # need to handle Identity and Control through separate means
146- # but we can handle X, Y, Z, and H here just fine
147- stim_1q_op = self .get_stim_1q_gate (applied_op )
148-
149- if isinstance (apply_stmt , qubit .Apply ):
150- qubits = apply_stmt .qubits
151- address_attribute : AddressAttribute = self .get_address (qubits )
152- # Should get an AddressTuple out of the address stored in attribute
153- address_tuple = address_attribute .address
154- qubit_idx_ssas : list [ir .SSAValue ] = []
155- for address_qubit in address_tuple .data :
156- qubit_idx = address_qubit .data
157- qubit_idx_stmt = py .Constant (qubit_idx )
158- qubit_idx_ssas .append (qubit_idx_stmt .result )
159- qubit_idx_stmt .insert_before (apply_stmt )
160-
161- stim_1q_stmt = stim_1q_op (targets = tuple (qubit_idx_ssas ))
162-
163- # can't do any of this because of dependencies downstream
164- # apply_stmt.replace_by(stim_1q_stmt)
165-
166- return RewriteResult (has_done_something = True )
182+ if isinstance (applied_op , op .stmts .Control ):
183+ return self .rewrite_Control (apply_stmt )
167184
168- elif isinstance (apply_stmt , wire .Apply ):
169- wires_ssa = apply_stmt .inputs
170- qubit_idx_ssas : list [ir .SSAValue ] = []
171- for wire_ssa in wires_ssa :
172- address_attribute = self .get_address (wire_ssa )
173- # get parent qubit idx
174- wire_address = address_attribute .address
175- qubit_idx = wire_address .origin_qubit .data
176- qubit_idx_stmt = py .Constant (qubit_idx )
177- # accumulate all qubit idx SSA to instantiate stim gate stmt
178- qubit_idx_ssas .append (qubit_idx_stmt .result )
179- qubit_idx_stmt .insert_before (apply_stmt )
185+ # need to handle Control through separate means
186+ # but we can handle X, Y, Z, H, and S here just fine
187+ stim_1q_op = self .get_stim_1q_gate (applied_op )
180188
181- stim_1q_stmt = stim_1q_op (targets = tuple (qubit_idx_ssas ))
182- stim_1q_stmt .insert_before (apply_stmt )
189+ qubit_idx_ssas = self .insert_qubit_idx_ssa (apply_stmt = apply_stmt )
190+ stim_1q_stmt = stim_1q_op (targets = tuple (qubit_idx_ssas ))
191+ stim_1q_stmt .insert_before (apply_stmt )
183192
184- # There is something depending on the results of the statement,
185- # need to handle that so replacement/deletion can occur without problems
193+ return RewriteResult (has_done_something = True )
186194
187- # apply's results become wires that go to other apply's/wrap stmts
188- # apply_stmt.replace_by(stim_1q_stmt)
195+ def rewrite_Control (
196+ self , apply_stmt_ctrl : qubit .Apply | wire .Apply
197+ ) -> RewriteResult :
198+ # stim only supports CX, CY, CZ so we have to check the
199+ # operator of Apply is a Control gate, enforce it's only asking for 1 control qubit,
200+ # and that the target of the control is X, Y, Z in squin
201+
202+ ctrl_op : op .stmts .Control = apply_stmt_ctrl .operator .owner
203+ # enforce that n_controls is 1
204+
205+ ctrl_op_target_gate = ctrl_op .op .owner
206+
207+ # should enforce that this is some multiple of 2
208+ qubit_idx_ssas = self .insert_qubit_idx_ssa (apply_stmt = apply_stmt_ctrl )
209+ # according to stim, final result can be:
210+ # CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4)
211+ target_qubits = []
212+ ctrl_qubits = []
213+ # definitely a better way to do this but
214+ # can't think of it right now
215+ for i in range (len (qubit_idx_ssas )):
216+ if (i % 2 ) == 0 :
217+ ctrl_qubits .append (qubit_idx_ssas [i ])
218+ else :
219+ target_qubits .append (qubit_idx_ssas [i ])
220+
221+ target_qubits = tuple (target_qubits )
222+ ctrl_qubits = tuple (ctrl_qubits )
223+
224+ match ctrl_op_target_gate :
225+ case op .stmts .X ():
226+ stim_stmt = stim .CX (controls = ctrl_qubits , targets = target_qubits )
227+ case op .stmts .Y ():
228+ stim_stmt = stim .CY (controls = ctrl_qubits , targets = target_qubits )
229+ case op .stmts .Z ():
230+ stim_stmt = stim .CZ (controls = ctrl_qubits , targets = target_qubits )
231+ case _:
232+ raise NotImplementedError (
233+ "Control gates beyond CX, CY, and CZ are not supported"
234+ )
189235
190- return RewriteResult ( has_done_something = True )
236+ stim_stmt . insert_before ( apply_stmt_ctrl )
191237
192- return RewriteResult ()
238+ return RewriteResult (has_done_something = True )
0 commit comments