1- from typing import Dict , Tuple
1+ from typing import Tuple
22from dataclasses import dataclass
33
44from kirin .ir import SSAValue , Statement
5- from kirin .analysis import const
6- from kirin .dialects import py
5+ from kirin .dialects import py , ilist
76from kirin .rewrite .abc import RewriteRule , RewriteResult
87
98from bloqade .squin import op , wire , noise as squin_noise , qubit
109from bloqade .stim .dialects import noise as stim_noise
1110from bloqade .stim .rewrite .util import (
11+ get_const_value ,
1212 create_wire_passthrough ,
1313 insert_qubit_idx_after_apply ,
1414)
1717@dataclass
1818class SquinNoiseToStim (RewriteRule ):
1919
20- cp_results : Dict [SSAValue , const .Result ]
21-
2220 def rewrite_Statement (self , node : Statement ) -> RewriteResult :
2321 match node :
2422 case qubit .Apply () | qubit .Broadcast ():
@@ -67,7 +65,7 @@ def rewrite_PauliError(
6765 assert isinstance (squin_channel , squin_noise .stmts .PauliError )
6866 basis = squin_channel .basis .owner
6967 assert isinstance (basis , op .stmts .PauliOp )
70- p = self . cp_results . get ( squin_channel .p ). data
68+ p = get_const_value ( float , squin_channel .p )
7169
7270 p_stmt = py .Constant (p )
7371 p_stmt .insert_before (stmt )
@@ -90,7 +88,7 @@ def rewrite_SingleQubitPauliChannel(
9088 squin_channel = stmt .operator .owner
9189 assert isinstance (squin_channel , squin_noise .stmts .SingleQubitPauliChannel )
9290
93- params = self . cp_results . get ( squin_channel .params ). data
91+ params = get_const_value ( ilist . IList , squin_channel .params )
9492 new_stmts = [
9593 p_x := py .Constant (params [0 ]),
9694 p_y := py .Constant (params [1 ]),
@@ -117,7 +115,7 @@ def rewrite_TwoQubitPauliChannel(
117115 squin_channel = stmt .operator .owner
118116 assert isinstance (squin_channel , squin_noise .stmts .TwoQubitPauliChannel )
119117
120- params = self . cp_results . get ( squin_channel .params ). data
118+ params = get_const_value ( ilist . IList , squin_channel .params )
121119 param_stmts = [py .Constant (p ) for p in params ]
122120 for param_stmt in param_stmts :
123121 param_stmt .insert_before (stmt )
@@ -141,3 +139,20 @@ def rewrite_TwoQubitPauliChannel(
141139 pzz = param_stmts [14 ].result ,
142140 )
143141 return stim_stmt
142+
143+ def rewrite_Depolarize (
144+ self ,
145+ stmt : qubit .Apply | qubit .Broadcast | wire .Broadcast | wire .Apply ,
146+ qubit_idx_ssas : Tuple [SSAValue ],
147+ ) -> Statement :
148+ """Rewrite squin.noise.Depolarize to stim.Depolarize1."""
149+
150+ squin_channel = stmt .operator .owner
151+ assert isinstance (squin_channel , squin_noise .stmts .Depolarize )
152+
153+ p = get_const_value (float , squin_channel .p )
154+ p_stmt = py .Constant (p )
155+ p_stmt .insert_before (stmt )
156+
157+ stim_stmt = stim_noise .Depolarize1 (targets = qubit_idx_ssas , p = p_stmt .result )
158+ return stim_stmt
0 commit comments