33from dataclasses import field , dataclass
44
55from kirin import ir
6- from kirin .rewrite import abc as rewrite_abc
76from kirin .dialects import py , ilist
7+ from kirin .rewrite .abc import RewriteRule
88from kirin .analysis .const import lattice
9+ from kirin .rewrite .result import RewriteResult
910
1011from bloqade .analysis import address
1112from bloqade .qasm2 .dialects import uop , core , parallel
1415
1516class MergePolicyABC (abc .ABC ):
1617 @abc .abstractmethod
17- def __call__ (self , node : ir .Statement ) -> rewrite_abc . RewriteResult :
18+ def __call__ (self , node : ir .Statement ) -> RewriteResult :
1819 pass
1920
2021 @classmethod
@@ -141,10 +142,10 @@ def from_analysis(
141142 group_numbers = group_numbers ,
142143 )
143144
144- def __call__ (self , node : ir .Statement ) -> rewrite_abc . RewriteResult :
145+ def __call__ (self , node : ir .Statement ) -> RewriteResult :
145146
146147 if node not in self .group_numbers :
147- return rewrite_abc . RewriteResult ()
148+ return RewriteResult ()
148149
149150 group_number = self .group_numbers [node ]
150151 group = self .merge_groups [group_number ]
@@ -157,9 +158,7 @@ def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
157158 if self .group_has_merged [group_number ]:
158159 node .delete ()
159160
160- return rewrite_abc .RewriteResult (
161- has_done_something = self .group_has_merged [group_number ]
162- )
161+ return RewriteResult (has_done_something = self .group_has_merged [group_number ])
163162
164163 def move_and_collect_qubit_list (
165164 self , qargs : List [ir .SSAValue ], node : ir .Statement
@@ -219,14 +218,14 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
219218 ctrls .append (stmt .ctrls )
220219 qargs .append (stmt .qargs )
221220 else :
222- return rewrite_abc . RewriteResult (has_done_something = False )
221+ return RewriteResult (has_done_something = False )
223222
224223 ctrls_values = self .move_and_collect_qubit_list (ctrls , node )
225224 qargs_values = self .move_and_collect_qubit_list (qargs , node )
226225
227226 if ctrls_values is None or qargs_values is None :
228227 # give up if we cannot determine the address or cannot move the qubits
229- return rewrite_abc . RewriteResult (has_done_something = False )
228+ return RewriteResult (has_done_something = False )
230229
231230 new_ctrls = ilist .New (values = ctrls_values )
232231 new_qargs = ilist .New (values = qargs_values )
@@ -238,7 +237,7 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
238237
239238 node .delete ()
240239
241- return rewrite_abc . RewriteResult (has_done_something = True )
240+ return RewriteResult (has_done_something = True )
242241
243242 def rewrite_group_U (self , node : ir .Statement , group : List [ir .Statement ]):
244243 return self .rewrite_group_u (node , group )
@@ -252,13 +251,13 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
252251 elif isinstance (stmt , parallel .UGate ):
253252 qargs .append (stmt .qargs )
254253 else :
255- return rewrite_abc . RewriteResult (has_done_something = False )
254+ return RewriteResult (has_done_something = False )
256255
257256 assert isinstance (node , (uop .UGate , parallel .UGate ))
258257 qargs_values = self .move_and_collect_qubit_list (qargs , node )
259258
260259 if qargs_values is None :
261- return rewrite_abc . RewriteResult (has_done_something = False )
260+ return RewriteResult (has_done_something = False )
262261
263262 new_qargs = ilist .New (values = qargs_values )
264263 new_gate = parallel .UGate (
@@ -271,7 +270,7 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
271270 new_gate .insert_before (node )
272271 node .delete ()
273272
274- return rewrite_abc . RewriteResult (has_done_something = True )
273+ return RewriteResult (has_done_something = True )
275274
276275 def rewrite_group_rz (self , node : ir .Statement , group : List [ir .Statement ]):
277276 qargs = []
@@ -282,14 +281,14 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
282281 elif isinstance (stmt , parallel .RZ ):
283282 qargs .append (stmt .qargs )
284283 else :
285- return rewrite_abc . RewriteResult (has_done_something = False )
284+ return RewriteResult (has_done_something = False )
286285
287286 assert isinstance (node , (uop .RZ , parallel .RZ ))
288287
289288 qargs_values = self .move_and_collect_qubit_list (qargs , node )
290289
291290 if qargs_values is None :
292- return rewrite_abc . RewriteResult (has_done_something = False )
291+ return RewriteResult (has_done_something = False )
293292
294293 new_qargs = ilist .New (values = qargs_values )
295294 new_gate = parallel .RZ (
@@ -300,7 +299,7 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
300299 new_gate .insert_before (node )
301300 node .delete ()
302301
303- return rewrite_abc . RewriteResult (has_done_something = True )
302+ return RewriteResult (has_done_something = True )
304303
305304 def rewrite_group_barrier (self , node : uop .Barrier , group : List [uop .Barrier ]):
306305 qargs = []
@@ -310,13 +309,13 @@ def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
310309 qargs_values = self .move_and_collect_qubit_list (qargs , node )
311310
312311 if qargs_values is None :
313- return rewrite_abc . RewriteResult (has_done_something = False )
312+ return RewriteResult (has_done_something = False )
314313
315314 new_node = uop .Barrier (qargs = qargs_values )
316315 new_node .insert_before (node )
317316 node .delete ()
318317
319- return rewrite_abc . RewriteResult (has_done_something = True )
318+ return RewriteResult (has_done_something = True )
320319
321320
322321class GreedyMixin (MergePolicyABC ):
@@ -385,11 +384,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):
385384
386385
387386@dataclass
388- class UOpToParallelRule (rewrite_abc . RewriteRule ):
387+ class UOpToParallelRule (RewriteRule ):
389388 merge_rewriters : Dict [ir .Block | None , MergePolicyABC ]
390389
391- def rewrite_Statement (self , node : ir .Statement ) -> rewrite_abc . RewriteResult :
390+ def rewrite_Statement (self , node : ir .Statement ) -> RewriteResult :
392391 merge_rewriter = self .merge_rewriters .get (
393- node .parent_block , lambda _ : rewrite_abc . RewriteResult ()
392+ node .parent_block , lambda _ : RewriteResult ()
394393 )
395394 return merge_rewriter (node )
0 commit comments