@@ -18,12 +18,20 @@ class RewritePlaceOperations(abc.RewriteRule):
1818 This is a placeholder for the actual implementation.
1919 """
2020
21- def default_ (self , node : ir .Statement ) -> abc .RewriteResult :
22- return abc .RewriteResult ()
23-
2421 def rewrite_Statement (self , node : ir .Statement ) -> abc .RewriteResult :
22+ if not isinstance (
23+ node ,
24+ (
25+ gemini_stmts .TerminalLogicalMeasurement ,
26+ gemini_stmts .Initialize ,
27+ gate .CZ ,
28+ gate .R ,
29+ gate .Rz ,
30+ ),
31+ ):
32+ return abc .RewriteResult ()
2533 rewrite_method_name = f"rewrite_{ type (node ).__name__ } "
26- rewrite_method = getattr (self , rewrite_method_name , self . default_ )
34+ rewrite_method = getattr (self , rewrite_method_name )
2735 return rewrite_method (node )
2836
2937 def prep_region (self ) -> tuple [ir .Region , ir .Block , ir .SSAValue ]:
@@ -46,6 +54,25 @@ def construct_execute(
4654
4755 return place .StaticPlacement (qubits = qubits , body = body )
4856
57+ def rewrite_Initialize (self , node : gemini_stmts .Initialize ) -> abc .RewriteResult :
58+ if not isinstance (args_list := node .qubits .owner , ilist .New ):
59+ return abc .RewriteResult ()
60+
61+ inputs = args_list .values
62+ body , block , entry_state = self .prep_region ()
63+ gate_stmt = place .Initialize (
64+ entry_state ,
65+ phi = node .phi ,
66+ theta = node .theta ,
67+ lam = node .lam ,
68+ qubits = tuple (range (len (inputs ))),
69+ )
70+ node .replace_by (
71+ self .construct_execute (gate_stmt , qubits = inputs , body = body , block = block )
72+ )
73+
74+ return abc .RewriteResult (has_done_something = True )
75+
4976 def rewrite_TerminalLogicalMeasurement (
5077 self , node : gemini_stmts .TerminalLogicalMeasurement
5178 ) -> abc .RewriteResult :
@@ -82,13 +109,11 @@ def rewrite_CZ(self, node: gate.CZ) -> abc.RewriteResult:
82109 return abc .RewriteResult ()
83110
84111 all_qubits = tuple (range (len (targets ) + len (controls )))
85- n_controls = len (controls )
86112
87113 body , block , entry_state = self .prep_region ()
88114 stmt = place .CZ (
89115 entry_state ,
90- controls = all_qubits [:n_controls ],
91- targets = all_qubits [n_controls :],
116+ qubits = all_qubits ,
92117 )
93118
94119 node .replace_by (
@@ -155,27 +180,6 @@ class MergePlacementRegions(abc.RewriteRule):
155180 merge_heuristic : Callable [[ir .Region , ir .Region ], bool ] = _default_merge_heuristic
156181 """Heuristic function to decide whether to merge two circuit regions."""
157182
158- def remap_qubits (
159- self ,
160- curr_state : ir .SSAValue ,
161- node : place .R | place .Rz | place .CZ | place .EndMeasure ,
162- input_map : dict [int , int ],
163- ):
164- if isinstance (node , place .CZ ):
165- return place .CZ (
166- curr_state ,
167- targets = tuple (input_map [i ] for i in node .targets ),
168- controls = tuple (input_map [i ] for i in node .controls ),
169- )
170- else :
171- return node .from_stmt (
172- node ,
173- args = (curr_state , * node .args [1 :]),
174- attributes = {
175- "qubits" : ir .PyAttr (tuple (input_map [i ] for i in node .qubits ))
176- },
177- )
178-
179183 def rewrite_Statement (self , node : ir .Statement ) -> abc .RewriteResult :
180184 if not (
181185 isinstance (node , place .StaticPlacement )
@@ -206,8 +210,18 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
206210 curr_yield .delete ()
207211
208212 for stmt in next_node .body .blocks [0 ].stmts :
209- if isinstance (stmt , (place .R , place .Rz , place .CZ , place .EndMeasure )):
210- remapped_stmt = self .remap_qubits (curr_state , stmt , new_input_map )
213+ if isinstance (
214+ stmt , (place .R , place .Rz , place .CZ , place .EndMeasure , place .Initialize )
215+ ):
216+ remapped_stmt = stmt .from_stmt (
217+ stmt ,
218+ args = (curr_state , * stmt .args [1 :]),
219+ attributes = {
220+ "qubits" : ir .PyAttr (
221+ tuple (new_input_map [i ] for i in stmt .qubits )
222+ )
223+ },
224+ )
211225 curr_state = remapped_stmt .results [0 ]
212226 new_block .stmts .append (remapped_stmt )
213227 for old_result , new_result in zip (
0 commit comments