88
99from bloqade import stim
1010from bloqade .squin import op , wire , qubit
11- from bloqade .analysis .address import Address , AddressWire , AddressTuple
11+ from bloqade .analysis .address import Address , AddressWire , AddressQubit , AddressTuple
1212from bloqade .squin .analysis .nsites import Sites
1313
1414# Probably best to move these attributes to a
@@ -82,9 +82,12 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
8282@dataclass
8383class _SquinToStim (RewriteRule ):
8484
85- def get_address (self , value : ir .SSAValue ):
85+ def get_address (self , value : ir .SSAValue ) -> AddressAttribute :
86+
8687 try :
87- return value .hints ["address" ]
88+ address_attr = value .hints ["address" ]
89+ assert isinstance (address_attr , AddressAttribute )
90+ return address_attr
8891 except KeyError :
8992 raise KeyError (f"The address analysis hint for { value } does not exist" )
9093
@@ -115,40 +118,80 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator):
115118 f"The squin operator { squin_op } is not supported in the stim dialect"
116119 )
117120
121+ def insert_qubit_idx_from_address (
122+ self , address : AddressAttribute , stmt_to_insert_before : ir .Statement
123+ ) -> tuple [ir .SSAValue , ...]:
124+
125+ address_data = address .address
126+
127+ qubit_idx_ssas = []
128+
129+ if isinstance (address_data , AddressTuple ):
130+ for address_qubit in address_data .data :
131+
132+ # ensure that the stuff in the AddressTuple should be AddressQubit
133+ # could handle AddressWires as well but don't see the need for that right now
134+ if not isinstance (address_qubit , AddressQubit ):
135+ raise ValueError (
136+ "Unsupported Address type detected inside AddressTuple, must be AddressQubit"
137+ )
138+ qubit_idx = address_qubit .data
139+ qubit_idx_stmt = py .Constant (qubit_idx )
140+ qubit_idx_stmt .insert_before (stmt_to_insert_before )
141+ qubit_idx_ssas .append (qubit_idx_stmt .result )
142+ elif isinstance (address_data , AddressWire ):
143+ address_qubit = address_data .origin_qubit
144+ qubit_idx = address_qubit .data
145+ qubit_idx_stmt = py .Constant (qubit_idx )
146+ qubit_idx_stmt .insert_before (stmt_to_insert_before )
147+ qubit_idx_ssas .append (qubit_idx_stmt .result )
148+ else :
149+ NotImplementedError (
150+ "qubit idx extraction and insertion only support for AddressTuple[AddressQubit] and AddressWire instances"
151+ )
152+
153+ return tuple (qubit_idx_ssas )
154+
155+ def insert_qubit_idx_from_wire_ssa (
156+ self , wire_ssas : tuple [ir .SSAValue , ...], stmt_to_insert_before : ir .Statement
157+ ) -> tuple [ir .SSAValue , ...]:
158+ qubit_idx_ssas = []
159+ for wire_ssa in wire_ssas :
160+ address_attribute = self .get_address (wire_ssa ) # get AddressWire
161+ # get parent qubit idx
162+ wire_address = address_attribute .address
163+ assert isinstance (wire_address , AddressWire )
164+ qubit_idx = wire_address .origin_qubit .data
165+ qubit_idx_stmt = py .Constant (qubit_idx )
166+ # accumulate all qubit idx SSA to instantiate stim gate stmt
167+ qubit_idx_ssas .append (qubit_idx_stmt .result )
168+ qubit_idx_stmt .insert_before (stmt_to_insert_before )
169+
170+ return tuple (qubit_idx_ssas )
171+
118172 # get the qubit indices from the Apply statement argument
119173 # wires/qubits
120- def insert_qubit_idx_ssa (
174+
175+ def insert_qubit_idx_after_apply (
121176 self , apply_stmt : wire .Apply | qubit .Apply
122177 ) -> tuple [ir .SSAValue , ...]:
123178
124179 if isinstance (apply_stmt , qubit .Apply ):
125180 qubits = apply_stmt .qubits
126181 address_attribute : AddressAttribute = self .get_address (qubits )
127182 # Should get an AddressTuple out of the address stored in attribute
128- address_tuple = address_attribute .address
129- qubit_idx_ssas : list [ir .SSAValue ] = []
130- for address_qubit in address_tuple .data :
131- qubit_idx = address_qubit .data
132- qubit_idx_stmt = py .Constant (qubit_idx )
133- qubit_idx_stmt .insert_before (apply_stmt )
134- qubit_idx_ssas .append (qubit_idx_stmt .result )
135-
136- return tuple (qubit_idx_ssas )
137-
183+ return self .insert_qubit_idx_from_address (
184+ address = address_attribute , stmt_to_insert_before = apply_stmt
185+ )
138186 elif isinstance (apply_stmt , wire .Apply ):
139187 wire_ssas = apply_stmt .inputs
140- qubit_idx_ssas : list [ir .SSAValue ] = []
141- for wire_ssa in wire_ssas :
142- address_attribute = self .get_address (wire_ssa )
143- # get parent qubit idx
144- wire_address = address_attribute .address
145- qubit_idx = wire_address .origin_qubit .data
146- qubit_idx_stmt = py .Constant (qubit_idx )
147- # accumulate all qubit idx SSA to instantiate stim gate stmt
148- qubit_idx_ssas .append (qubit_idx_stmt .result )
149- qubit_idx_stmt .insert_before (apply_stmt )
150-
151- return tuple (qubit_idx_ssas )
188+ return self .insert_qubit_idx_from_wire_ssa (
189+ wire_ssas = wire_ssas , stmt_to_insert_before = apply_stmt
190+ )
191+ else :
192+ raise TypeError (
193+ "unsupported statement detected, only wire.Apply and qubit.Apply statements are supported by this method"
194+ )
152195
153196 # might be worth attempting multiple dispatch like qasm2 rewrites
154197 # for Glob and Parallel to UOp
@@ -168,6 +211,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
168211 return self .rewrite_Wrap (node )
169212 case wire .Measure () | qubit .Measure ():
170213 return self .rewrite_Measure (node )
214+ case wire .Reset () | qubit .Reset ():
215+ return self .rewrite_Reset (node )
171216 case _:
172217 return RewriteResult ()
173218
@@ -188,6 +233,7 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
188233
189234 # this is an SSAValue, need it to be the actual operator
190235 applied_op = apply_stmt .operator .owner
236+ assert isinstance (applied_op , op .stmts .Operator )
191237
192238 if isinstance (applied_op , op .stmts .Control ):
193239 return self .rewrite_Control (apply_stmt )
@@ -196,7 +242,25 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
196242 # but we can handle X, Y, Z, H, and S here just fine
197243 stim_1q_op = self .get_stim_1q_gate (applied_op )
198244
199- qubit_idx_ssas = self .insert_qubit_idx_ssa (apply_stmt = apply_stmt )
245+ # wire.Apply -> tuple of SSA -> AddressTuple
246+ # qubit.Apply -> list of qubits -> AddressTuple
247+ ## Both cases the statements follow the Stim semantics of
248+ ## 1QGate a b c d ....
249+
250+ if isinstance (apply_stmt , qubit .Apply ):
251+ address_attr = self .get_address (apply_stmt .qubits )
252+ qubit_idx_ssas = self .insert_qubit_idx_from_address (
253+ address = address_attr , stmt_to_insert_before = apply_stmt
254+ )
255+ elif isinstance (apply_stmt , wire .Apply ):
256+ qubit_idx_ssas = self .insert_qubit_idx_from_wire_ssa (
257+ wire_ssas = apply_stmt .inputs , stmt_to_insert_before = apply_stmt
258+ )
259+ else :
260+ raise TypeError (
261+ "Unsupported statement detected, only qubit.Apply and wire.Apply are permitted"
262+ )
263+
200264 stim_1q_stmt = stim_1q_op (targets = tuple (qubit_idx_ssas ))
201265 stim_1q_stmt .insert_before (apply_stmt )
202266
@@ -209,13 +273,16 @@ def rewrite_Control(
209273 # operator of Apply is a Control gate, enforce it's only asking for 1 control qubit,
210274 # and that the target of the control is X, Y, Z in squin
211275
212- ctrl_op : op .stmts .Control = apply_stmt_ctrl .operator .owner
276+ ctrl_op = apply_stmt_ctrl .operator .owner
277+ assert isinstance (ctrl_op , op .stmts .Control )
278+
213279 # enforce that n_controls is 1
214280
215281 ctrl_op_target_gate = ctrl_op .op .owner
282+ assert isinstance (ctrl_op_target_gate , op .stmts .Operator )
216283
217284 # should enforce that this is some multiple of 2
218- qubit_idx_ssas = self .insert_qubit_idx_ssa (apply_stmt = apply_stmt_ctrl )
285+ qubit_idx_ssas = self .insert_qubit_idx_after_apply (apply_stmt = apply_stmt_ctrl )
219286 # according to stim, final result can be:
220287 # CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4)
221288 target_qubits = []
@@ -254,26 +321,26 @@ def rewrite_Measure(
254321 if isinstance (measure_stmt , qubit .Measure ):
255322 qubit_ilist_ssa = measure_stmt .qubits
256323 # qubits are in an ilist which makes up an AddressTuple
257- address_tuple : AddressTuple = self .get_address (qubit_ilist_ssa ).address
258- qubit_idx_ssas = []
259- for qubit_address in address_tuple :
260- qubit_idx = qubit_address .data
261- qubit_idx_stmt = py .constant .Constant (qubit_idx )
262- qubit_idx_stmt .insert_before (measure_stmt )
263- qubit_idx_ssas .append (qubit_idx_stmt .result )
264- qubit_idx_ssas = tuple (qubit_idx_ssas )
324+ address_attr = self .get_address (qubit_ilist_ssa )
265325
266326 elif isinstance (measure_stmt , wire .Measure ):
327+ # Wire Terminator, should kill the existence of
328+ # the wire here so DCE can sweep up the rest like with rewriting wrap
267329 wire_ssa = measure_stmt .wire
268- wire_address : AddressWire = self .get_address (wire_ssa ). address
330+ address_attr = self .get_address (wire_ssa )
269331
270- qubit_idx = wire_address .origin_qubit .data
271- qubit_idx_stmt = py .constant .Constant (qubit_idx )
272- qubit_idx_stmt .insert_before (measure_stmt )
273- qubit_idx_ssas = (qubit_idx_stmt .result ,)
332+ # DCE can't remove the old measure_stmt for both wire and qubit versions
333+ # because of the fact it has a result that can be depended on by other statements
334+ # whereas Stim Measure has no such notion
274335
275336 else :
276- return RewriteResult ()
337+ raise TypeError (
338+ "unsupported Statement, only qubit.Measure and wire.Measure are supported"
339+ )
340+
341+ qubit_idx_ssas = self .insert_qubit_idx_from_address (
342+ address = address_attr , stmt_to_insert_before = measure_stmt
343+ )
277344
278345 prob_noise_stmt = py .constant .Constant (0.0 )
279346 stim_measure_stmt = stim .collapse .MZ (
@@ -284,3 +351,43 @@ def rewrite_Measure(
284351 stim_measure_stmt .insert_before (measure_stmt )
285352
286353 return RewriteResult (has_done_something = True )
354+
355+ def rewrite_Reset (self , reset_stmt : qubit .Reset | wire .Reset ) -> RewriteResult :
356+ """
357+ qubit.Reset(ilist of qubits) -> nothing
358+ # safe to delete the statement afterwards, no depending results
359+ # DCE could probably do this automatically?
360+
361+ wire.Reset(single wire) -> new wire
362+ # DO NOT DELETE
363+
364+ # assume RZ, but could extend to RY and RX later
365+ Stim RZ(targets = tuple[int of SSAVals])
366+ """
367+
368+ if isinstance (reset_stmt , qubit .Reset ):
369+ qubit_ilist_ssa = reset_stmt .qubits
370+ # qubits are in an ilist which makes up an AddressTuple
371+ address_attr = self .get_address (qubit_ilist_ssa )
372+ qubit_idx_ssas = self .insert_qubit_idx_from_address (
373+ address = address_attr , stmt_to_insert_before = reset_stmt
374+ )
375+ elif isinstance (reset_stmt , wire .Reset ):
376+ address_attr = self .get_address (reset_stmt .wire )
377+ qubit_idx_ssas = self .insert_qubit_idx_from_address (
378+ address = address_attr , stmt_to_insert_before = reset_stmt
379+ )
380+ else :
381+ raise TypeError (
382+ "unsupported statement, only qubit.Reset and wire.Reset are supported"
383+ )
384+
385+ stim_rz_stmt = stim .collapse .stmts .RZ (targets = qubit_idx_ssas )
386+ stim_rz_stmt .insert_before (reset_stmt )
387+
388+ return RewriteResult (has_done_something = True )
389+
390+ def rewrite_MeasureAndReset (
391+ self , meas_and_reset_stmt : qubit .MeasureAndReset | wire .MeasureAndReset
392+ ):
393+ pass
0 commit comments