55from kirin import ir
66from kirin .analysis .forward import Forward , ForwardFrame
77from kirin .lattice import EmptyLattice
8- from typing_extensions import Self
98
109from bloqade .lanes .layout .encoding import LocationAddress
1110
@@ -39,12 +38,15 @@ def compute_layout(
3938
4039@dataclass
4140class LayoutAnalysis (Forward ):
42- keys = ("circuit .layout" ,)
41+ keys = ("place .layout" ,)
4342 lattice = EmptyLattice
4443
4544 heuristic : LayoutHeuristicABC
4645 address_entries : dict [ir .SSAValue , address .Address ]
4746 all_qubits : tuple [int , ...] = field (init = False )
47+ thetas : dict [int , ir .SSAValue ] = field (default_factory = dict , init = False )
48+ phis : dict [int , ir .SSAValue ] = field (default_factory = dict , init = False )
49+ lams : dict [int , ir .SSAValue ] = field (default_factory = dict , init = False )
4850 stages : list [tuple [tuple [int , int ], ...]] = field (default_factory = list , init = False )
4951 global_address_stack : list [int ] = field (default_factory = list , init = False )
5052
@@ -60,9 +62,12 @@ def __post_init__(self) -> None:
6062 )
6163 super ().__post_init__ ()
6264
63- def initialize (self ) -> Self :
65+ def initialize (self ):
6466 self .stages .clear ()
6567 self .global_address_stack .clear ()
68+ self .thetas .clear ()
69+ self .phis .clear ()
70+ self .lams .clear ()
6671 return super ().initialize ()
6772
6873 def eval_stmt_fallback (self , frame , stmt ):
@@ -76,15 +81,36 @@ def add_stage(self, control: tuple[int, ...], target: tuple[int, ...]):
7681 def method_self (self , method : ir .Method ):
7782 return EmptyLattice .bottom ()
7883
79- def get_layout_no_raise (self , method : ir .Method ) -> tuple [LocationAddress , ...]:
84+ def process_results (self ):
85+ layout = self .heuristic .compute_layout (self .all_qubits , self .stages )
86+ init_locations = tuple (
87+ loc
88+ for qubit_id , loc in zip (self .all_qubits , layout )
89+ if qubit_id in self .thetas
90+ )
91+ thetas = tuple (
92+ self .thetas [qubit_id ]
93+ for qubit_id in self .all_qubits
94+ if qubit_id in self .thetas
95+ )
96+ phis = tuple (
97+ self .phis [qubit_id ] for qubit_id in self .all_qubits if qubit_id in self .phis
98+ )
99+ lams = tuple (
100+ self .lams [qubit_id ] for qubit_id in self .all_qubits if qubit_id in self .lams
101+ )
102+
103+ return layout , init_locations , thetas , phis , lams
104+
105+ def get_layout_no_raise (self , method : ir .Method ):
80106 """Get the layout for a given method."""
81107 self .run_no_raise (method )
82- return self .heuristic . compute_layout ( self . all_qubits , self . stages )
108+ return self .process_results ( )
83109
84- def get_layout (self , method : ir .Method ) -> tuple [ LocationAddress , ...] :
110+ def get_layout (self , method : ir .Method ):
85111 """Get the layout for a given method."""
86112 self .run (method )
87- return self .heuristic . compute_layout ( self . all_qubits , self . stages )
113+ return self .process_results ( )
88114
89115 def eval_fallback (self , frame : ForwardFrame , node : ir .Statement ):
90116 return tuple (EmptyLattice .bottom () for _ in node .results )
0 commit comments