33
44from kirin import ir
55from kirin .rewrite import abc
6- from kirin .analysis import const
76from kirin .dialects import ilist
87
98from bloqade .analysis import address
@@ -20,28 +19,24 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
2019 return abc .RewriteResult ()
2120
2221 qargs = node .qargs
23- qarg_addresses = self .address_analysis .get (qargs , None )
22+ qargs_address = self .address_analysis .get (qargs , address . Unknown () )
2423
25- if isinstance (qarg_addresses , address .AddressReg ):
26- # NOTE: we only have an AddressReg if it's an entire register, definitely rewrite that
27- return self ._rewrite_parallel_to_glob (node )
28-
29- if not isinstance (qarg_addresses , address .AddressTuple ):
24+ if not isinstance (qargs_address , address .AddressReg ):
3025 return abc .RewriteResult ()
3126
32- idxs , qreg = self ._find_qreg (qargs .owner , set () )
27+ qregs = self ._get_all_qreg (qargs .owner )
3328
34- if qreg is None :
35- # NOTE: no unique register found
29+ if len (qregs ) != 1 :
3630 return abc .RewriteResult ()
3731
38- if not isinstance (hint := qreg .n_qubits .hints .get ("const" ), const .Value ):
39- # NOTE: non-constant number of qubits
32+ qreg = next (iter (qregs ))
33+
34+ qreg_address = self .address_analysis .get (qreg , address .Unknown ())
35+
36+ if not isinstance (qreg_address , address .AddressReg ):
4037 return abc .RewriteResult ()
4138
42- n = hint .data
43- if len (idxs ) != n :
44- # NOTE: not all qubits of the register are there
39+ if set (qargs_address .data ) != set (qreg_address .data ):
4540 return abc .RewriteResult ()
4641
4742 return self ._rewrite_parallel_to_glob (node )
@@ -53,6 +48,24 @@ def _rewrite_parallel_to_glob(node: parallel.UGate) -> abc.RewriteResult:
5348 node .replace_by (global_u )
5449 return abc .RewriteResult (has_done_something = True )
5550
51+ @staticmethod
52+ def _get_all_qreg (owner : ir .Statement | ir .Block ):
53+ stack = [owner ]
54+ qregs : set [ir .SSAValue ] = set ()
55+ while stack :
56+ current = stack .pop ()
57+
58+ if isinstance (current , core .stmts .QRegGet ):
59+ stack .append (current .reg .owner )
60+ elif isinstance (current , ilist .New ):
61+ for val in current .values :
62+ stack .append (val .owner )
63+
64+ elif isinstance (current , core .QRegNew ):
65+ qregs .add (current .result )
66+
67+ return qregs
68+
5669 @staticmethod
5770 def _find_qreg (
5871 qargs_owner : ir .Statement | ir .Block , idxs : set
0 commit comments