@@ -82,7 +82,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
8282@dataclass
8383class _SquinToStim (RewriteRule ):
8484
85- def get_address (self , value : ir .SSAValue ) -> AddressAttribute :
85+ def get_address_attr (self , value : ir .SSAValue ) -> AddressAttribute :
8686
8787 try :
8888 address_attr = value .hints ["address" ]
@@ -91,7 +91,7 @@ def get_address(self, value: ir.SSAValue) -> AddressAttribute:
9191 except KeyError :
9292 raise KeyError (f"The address analysis hint for { value } does not exist" )
9393
94- def get_sites (self , value : ir .SSAValue ):
94+ def get_sites_attr (self , value : ir .SSAValue ):
9595 try :
9696 return value .hints ["sites" ]
9797 except KeyError :
@@ -157,7 +157,7 @@ def insert_qubit_idx_from_wire_ssa(
157157 ) -> tuple [ir .SSAValue , ...]:
158158 qubit_idx_ssas = []
159159 for wire_ssa in wire_ssas :
160- address_attribute = self .get_address (wire_ssa ) # get AddressWire
160+ address_attribute = self .get_address_attr (wire_ssa ) # get AddressWire
161161 # get parent qubit idx
162162 wire_address = address_attribute .address
163163 assert isinstance (wire_address , AddressWire )
@@ -178,7 +178,7 @@ def insert_qubit_idx_after_apply(
178178
179179 if isinstance (apply_stmt , qubit .Apply ):
180180 qubits = apply_stmt .qubits
181- address_attribute : AddressAttribute = self .get_address (qubits )
181+ address_attribute : AddressAttribute = self .get_address_attr (qubits )
182182 # Should get an AddressTuple out of the address stored in attribute
183183 return self .insert_qubit_idx_from_address (
184184 address = address_attribute , stmt_to_insert_before = apply_stmt
@@ -213,6 +213,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
213213 return self .rewrite_Measure (node )
214214 case wire .Reset () | qubit .Reset ():
215215 return self .rewrite_Reset (node )
216+ case wire .MeasureAndReset () | qubit .MeasureAndReset ():
217+ return self .rewrite_MeasureAndReset (node )
216218 case _:
217219 return RewriteResult ()
218220
@@ -248,7 +250,7 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult:
248250 ## 1QGate a b c d ....
249251
250252 if isinstance (apply_stmt , qubit .Apply ):
251- address_attr = self .get_address (apply_stmt .qubits )
253+ address_attr = self .get_address_attr (apply_stmt .qubits )
252254 qubit_idx_ssas = self .insert_qubit_idx_from_address (
253255 address = address_attr , stmt_to_insert_before = apply_stmt
254256 )
@@ -321,13 +323,13 @@ def rewrite_Measure(
321323 if isinstance (measure_stmt , qubit .Measure ):
322324 qubit_ilist_ssa = measure_stmt .qubits
323325 # qubits are in an ilist which makes up an AddressTuple
324- address_attr = self .get_address (qubit_ilist_ssa )
326+ address_attr = self .get_address_attr (qubit_ilist_ssa )
325327
326328 elif isinstance (measure_stmt , wire .Measure ):
327329 # Wire Terminator, should kill the existence of
328330 # the wire here so DCE can sweep up the rest like with rewriting wrap
329331 wire_ssa = measure_stmt .wire
330- address_attr = self .get_address (wire_ssa )
332+ address_attr = self .get_address_attr (wire_ssa )
331333
332334 # DCE can't remove the old measure_stmt for both wire and qubit versions
333335 # because of the fact it has a result that can be depended on by other statements
@@ -368,12 +370,12 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult:
368370 if isinstance (reset_stmt , qubit .Reset ):
369371 qubit_ilist_ssa = reset_stmt .qubits
370372 # qubits are in an ilist which makes up an AddressTuple
371- address_attr = self .get_address (qubit_ilist_ssa )
373+ address_attr = self .get_address_attr (qubit_ilist_ssa )
372374 qubit_idx_ssas = self .insert_qubit_idx_from_address (
373375 address = address_attr , stmt_to_insert_before = reset_stmt
374376 )
375377 elif isinstance (reset_stmt , wire .Reset ):
376- address_attr = self .get_address (reset_stmt .wire )
378+ address_attr = self .get_address_attr (reset_stmt .wire )
377379 qubit_idx_ssas = self .insert_qubit_idx_from_address (
378380 address = address_attr , stmt_to_insert_before = reset_stmt
379381 )
@@ -390,4 +392,43 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult:
390392 def rewrite_MeasureAndReset (
391393 self , meas_and_reset_stmt : qubit .MeasureAndReset | wire .MeasureAndReset
392394 ):
393- pass
395+ """
396+ qubit.MeasureAndReset(qubits) -> result
397+ Could be translated (roughly equivalent) to
398+
399+ stim.MZ(tuple[SSAvals for ints])
400+ stim.RZ(tuple[SSAvals for ints])
401+
402+ Stim does have MRZ, might be more reflective of what we want/
403+ lines up the semantics better
404+
405+ """
406+
407+ if isinstance (meas_and_reset_stmt , qubit .MeasureAndReset ):
408+
409+ address_attr = self .get_address_attr (meas_and_reset_stmt .qubits )
410+ qubit_idx_ssas = self .insert_qubit_idx_from_address (
411+ address = address_attr , stmt_to_insert_before = meas_and_reset_stmt
412+ )
413+
414+ elif isinstance (meas_and_reset_stmt , wire .MeasureAndReset ):
415+ address_attr = self .get_address_attr (meas_and_reset_stmt .wire )
416+ qubit_idx_ssas = self .insert_qubit_idx_from_address (
417+ address_attr , stmt_to_insert_before = meas_and_reset_stmt
418+ )
419+
420+ else :
421+ raise TypeError (
422+ "Unsupported statement detected, only qubit.MeasureAndReset and wire.MeasureAndReset are supported"
423+ )
424+
425+ error_p_stmt = py .Constant (0.0 )
426+ stim_mz_stmt = stim .collapse .MZ (targets = qubit_idx_ssas , p = error_p_stmt .result )
427+ stim_rz_stmt = stim .collapse .RZ (
428+ targets = qubit_idx_ssas ,
429+ )
430+ error_p_stmt .insert_before (meas_and_reset_stmt )
431+ stim_mz_stmt .insert_before (meas_and_reset_stmt )
432+ stim_rz_stmt .insert_before (meas_and_reset_stmt )
433+
434+ return RewriteResult (has_done_something = True )
0 commit comments