66from bloqade .squin import op , wire
77from bloqade .squin .rewrite .wrap_analysis import AddressAttribute
88from bloqade .squin .rewrite .stim_rewrite_util import (
9+ SQUIN_STIM_GATE_MAPPING ,
910 rewrite_Control ,
10- get_stim_1q_gate ,
1111 are_sites_compatible ,
12+ is_measure_result_used ,
1213 insert_qubit_idx_from_address ,
1314 insert_qubit_idx_from_wire_ssa ,
1415)
1718class SquinWireToStim (RewriteRule ):
1819
1920 def rewrite_Statement (self , node : ir .Statement ) -> RewriteResult :
20-
21- rewrite_methods = {
22- wire .Apply : self .rewrite_Apply_and_Broadcast ,
23- wire .Broadcast : self .rewrite_Apply_and_Broadcast ,
24- wire .Wrap : self .rewrite_Wrap ,
25- wire .Measure : self .rewrite_Measure ,
26- wire .Reset : self .rewrite_Reset ,
27- wire .MeasureAndReset : self .rewrite_MeasureAndReset ,
28- }
29-
30- rewrite_method = rewrite_methods .get (type (node ))
31- if rewrite_method is None :
32- return RewriteResult ()
33-
34- return rewrite_method (node )
21+ match node :
22+ case wire .Apply () | wire .Broadcast ():
23+ return self .rewrite_Apply_and_Broadcast (node )
24+ case wire .Measure ():
25+ return self .rewrite_Measure (node )
26+ case wire .Reset ():
27+ return self .rewrite_Reset (node )
28+ case wire .MeasureAndReset ():
29+ return self .rewrite_MeasureAndReset (node )
30+ case _:
31+ return RewriteResult ()
3532
3633 def rewrite_Apply_and_Broadcast (
3734 self , stmt : wire .Apply | wire .Broadcast
@@ -47,7 +44,7 @@ def rewrite_Apply_and_Broadcast(
4744 if isinstance (applied_op , op .stmts .Control ):
4845 return rewrite_Control (stmt )
4946
50- stim_1q_op = get_stim_1q_gate ( applied_op )
47+ stim_1q_op = SQUIN_STIM_GATE_MAPPING . get ( type ( applied_op ) )
5148 if stim_1q_op is None :
5249 return RewriteResult ()
5350
@@ -69,21 +66,11 @@ def rewrite_Apply_and_Broadcast(
6966
7067 return RewriteResult (has_done_something = True )
7168
72- def rewrite_Wrap (self , wrap_stmt : wire .Wrap ) -> RewriteResult :
73-
74- # structure at this point should be:
75- ## w = wire.Unwrap(wire)
76- ## wire.Wrap(qubit, w)
77-
78- wire_origin_stmt = wrap_stmt .wire .owner
79- if isinstance (wire_origin_stmt , wire .Unwrap ):
80- wrap_stmt .delete ()
81- return RewriteResult (has_done_something = True )
82-
83- return RewriteResult ()
84-
8569 def rewrite_Measure (self , measure_stmt : wire .Measure ) -> RewriteResult :
8670
71+ if is_measure_result_used (measure_stmt ):
72+ return RewriteResult ()
73+
8774 wire_ssa = measure_stmt .wire
8875 address_attr = wire_ssa .hints .get ("address" )
8976 if address_attr is None :
@@ -126,6 +113,9 @@ def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult:
126113
127114 def rewrite_MeasureAndReset (self , meas_and_reset_stmt : wire .MeasureAndReset ):
128115
116+ if is_measure_result_used (meas_and_reset_stmt ):
117+ return RewriteResult ()
118+
129119 address_attr = meas_and_reset_stmt .wire .hints .get ("address" )
130120 if address_attr is None :
131121 return RewriteResult ()
@@ -141,6 +131,7 @@ def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset):
141131 stim_rz_stmt = stim .collapse .RZ (
142132 targets = qubit_idx_ssas ,
143133 )
134+
144135 error_p_stmt .insert_before (meas_and_reset_stmt )
145136 stim_mz_stmt .insert_before (meas_and_reset_stmt )
146137 meas_and_reset_stmt .replace_by (stim_rz_stmt )
0 commit comments