33from kirin import ir
44from kirin .passes import Pass
55from kirin .rewrite import Walk
6- from kirin .dialects import ilist
6+ from kirin .dialects import py , ilist
77from kirin .rewrite .abc import RewriteRule , RewriteResult
88
99from .stmts import (
1010 PPError ,
1111 QubitLoss ,
1212 Depolarize ,
1313 PauliError ,
14+ Depolarize2 ,
1415 NoiseChannel ,
1516 TwoQubitPauliChannel ,
1617 SingleQubitPauliChannel ,
@@ -58,6 +59,18 @@ def rewrite_single_qubit_pauli_channel(
5859 def rewrite_two_qubit_pauli_channel (
5960 self , node : TwoQubitPauliChannel
6061 ) -> RewriteResult :
62+ operator_list = self ._insert_two_qubit_paulis_before_node (node )
63+ stochastic_unitary = StochasticUnitaryChannel (
64+ operators = operator_list , probabilities = node .params
65+ )
66+
67+ node .replace_by (stochastic_unitary )
68+ return RewriteResult (has_done_something = True )
69+
70+ @staticmethod
71+ def _insert_two_qubit_paulis_before_node (
72+ node : TwoQubitPauliChannel | Depolarize2 ,
73+ ) -> ir .ResultValue :
6174 paulis = (Identity (sites = 1 ), X (), Y (), Z ())
6275 for op in paulis :
6376 op .insert_before (node )
@@ -71,12 +84,7 @@ def rewrite_two_qubit_pauli_channel(
7184 operators .append (op .result )
7285
7386 (operator_list := ilist .New (values = operators )).insert_before (node )
74- stochastic_unitary = StochasticUnitaryChannel (
75- operators = operator_list .result , probabilities = node .params
76- )
77-
78- node .replace_by (stochastic_unitary )
79- return RewriteResult (has_done_something = True )
87+ return operator_list .result
8088
8189 def rewrite_p_p_error (self , node : PPError ) -> RewriteResult :
8290 (operators := ilist .New (values = (node .op ,))).insert_before (node )
@@ -95,8 +103,14 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
95103 op .insert_before (node )
96104 operators .append (op .result )
97105
106+ # NOTE: need to divide the probability by 3 to get the correct total error rate
107+ (three := py .Constant (3 )).insert_before (node )
108+ (p_over_3 := py .Div (node .p , three .result )).insert_before (node )
109+
98110 (operator_list := ilist .New (values = operators )).insert_before (node )
99- (ps := ilist .New (values = [node .p for _ in range (3 )])).insert_before (node )
111+ (ps := ilist .New (values = [p_over_3 .result for _ in range (3 )])).insert_before (
112+ node
113+ )
100114
101115 stochastic_unitary = StochasticUnitaryChannel (
102116 operators = operator_list .result , probabilities = ps .result
@@ -105,6 +119,21 @@ def rewrite_depolarize(self, node: Depolarize) -> RewriteResult:
105119
106120 return RewriteResult (has_done_something = True )
107121
122+ def rewrite_depolarize2 (self , node : Depolarize2 ) -> RewriteResult :
123+ operator_list = self ._insert_two_qubit_paulis_before_node (node )
124+
125+ # NOTE: need to divide the probability by 15 to get the correct total error rate
126+ (fifteen := py .Constant (15 )).insert_before (node )
127+ (p_over_15 := py .Div (node .p , fifteen .result )).insert_before (node )
128+ (probs := ilist .New (values = [p_over_15 .result ] * 15 )).insert_before (node )
129+
130+ stochastic_unitary = StochasticUnitaryChannel (
131+ operators = operator_list , probabilities = probs .result
132+ )
133+ node .replace_by (stochastic_unitary )
134+
135+ return RewriteResult (has_done_something = True )
136+
108137
109138class RewriteNoiseStmts (Pass ):
110139 def unsafe_run (self , mt : ir .Method ):
0 commit comments