11from kirin import ir
2- from kirin .dialects import py , func
2+ from kirin .dialects import func
33from kirin .rewrite .abc import RewriteRule , RewriteResult
44
55from bloqade import squin
66from bloqade .qasm2 .dialects .noise import stmts as noise_stmts
77
8+ from .util import num_to_py_constant
89
9- class QASM2NoiseToSquin (RewriteRule ):
10-
11- def rewrite_Statement (self , node : ir .Statement ) -> RewriteResult :
12-
13- match node :
14- case noise_stmts .AtomLossChannel ():
15- return self .rewrite_AtomLossChannel (node )
16- case noise_stmts .PauliChannel ():
17- return self .rewrite_PauliChannel (node )
18- case noise_stmts .CZPauliChannel ():
19- return self .rewrite_CZPauliChannel (node )
20- case _:
21- return RewriteResult ()
22-
23- return RewriteResult ()
10+ NOISE_TO_SQUIN_MAP = {
11+ noise_stmts .AtomLossChannel : squin .broadcast .qubit_loss ,
12+ noise_stmts .PauliChannel : squin .broadcast .single_qubit_pauli_channel ,
13+ }
2414
25- def rewrite_AtomLossChannel (
26- self , stmt : noise_stmts .AtomLossChannel
27- ) -> RewriteResult :
2815
29- qargs = stmt .qargs
30- # this is a raw float, not in SSA form yet!
31- prob = stmt .prob
32- prob_stmt = py .Constant (value = prob )
33- prob_stmt .insert_before (stmt )
34-
35- invoke_loss_stmt = func .Invoke (
36- callee = squin .broadcast .qubit_loss ,
37- inputs = (prob_stmt .result , qargs ),
38- )
39-
40- stmt .replace_by (invoke_loss_stmt )
41-
42- return RewriteResult (has_done_something = True )
43-
44- def rewrite_PauliChannel (self , stmt : noise_stmts .PauliChannel ) -> RewriteResult :
45-
46- qargs = stmt .qargs
47- p_x = stmt .px
48- p_y = stmt .py
49- p_z = stmt .pz
50-
51- probs = [p_x , p_y , p_z ]
52- probs_ssas = []
16+ class QASM2NoiseToSquin (RewriteRule ):
5317
54- for prob in probs :
55- prob_stmt = py .Constant (value = prob )
56- prob_stmt .insert_before (stmt )
57- probs_ssas .append (prob_stmt .result )
18+ def rewrite_Statement (self , node : ir .Statement ) -> RewriteResult :
5819
59- invoke_pauli_channel_stmt = func .Invoke (
60- callee = squin .broadcast .single_qubit_pauli_channel ,
61- inputs = (* probs_ssas , qargs ),
20+ if isinstance (node , noise_stmts .AtomLossChannel ):
21+ qargs = node .qargs
22+ prob = node .prob
23+ prob_ssas = num_to_py_constant ([prob ], stmt_to_insert_before = node )
24+ elif isinstance (node , noise_stmts .PauliChannel ):
25+ qargs = node .qargs
26+ p_x = node .px
27+ p_y = node .py
28+ p_z = node .pz
29+ prob_ssas = num_to_py_constant ([p_x , p_y , p_z ], stmt_to_insert_before = node )
30+ elif isinstance (node , noise_stmts .CZPauliChannel ):
31+ return self .rewrite_CZPauliChannel (node )
32+ else :
33+ return RewriteResult ()
34+
35+ squin_noise_stmt = NOISE_TO_SQUIN_MAP [type (node )]
36+ invoke_stmt = func .Invoke (
37+ callee = squin_noise_stmt ,
38+ inputs = (* prob_ssas , qargs ),
6239 )
63-
64- stmt .replace_by (invoke_pauli_channel_stmt )
40+ node .replace_by (invoke_stmt )
6541 return RewriteResult (has_done_something = True )
6642
6743 def rewrite_CZPauliChannel (self , stmt : noise_stmts .CZPauliChannel ) -> RewriteResult :
@@ -78,11 +54,8 @@ def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteRes
7854
7955 error_probs = [px_ctrl , py_ctrl , pz_ctrl , px_qarg , py_qarg , pz_qarg ]
8056 # first half of entries for control qubits, other half for targets
81- error_prob_ssas = []
82- for error_prob in error_probs :
83- error_prob_stmt = py .Constant (value = error_prob )
84- error_prob_stmt .insert_before (stmt )
85- error_prob_ssas .append (error_prob_stmt .result )
57+
58+ error_prob_ssas = num_to_py_constant (error_probs , stmt_to_insert_before = stmt )
8659
8760 ctrl_pauli_channel_invoke = func .Invoke (
8861 callee = squin .broadcast .single_qubit_pauli_channel ,
0 commit comments