11from kirin import ir
2- from kirin .dialects import func
2+ from kirin .dialects import func , py
33from kirin .rewrite import abc
44
55from bloqade .shuttle .dialects import path , schedule
@@ -24,6 +24,31 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
2424 return abc .RewriteResult (has_done_something = True )
2525
2626
27+ class RewriteAutoInvoke (abc .RewriteRule ):
28+ def rewrite_Statement (self , node : ir .Statement ) -> abc .RewriteResult :
29+ if not isinstance (node .parent_stmt , schedule .Auto ):
30+ return abc .RewriteResult ()
31+
32+ if isinstance (node , func .Invoke ):
33+ (callee_stmt := py .Constant (node .callee )).insert_before (node )
34+ callee_ssa = callee_stmt .result
35+ elif isinstance (node , func .Call ):
36+ callee_ssa = node .callee
37+ else :
38+ return abc .RewriteResult ()
39+
40+ (tweezer_task := schedule .NewTweezerTask (move_fn = callee_ssa )).insert_before (
41+ node
42+ )
43+ (path .Gen (tweezer_task .result , node .inputs , kwargs = node .kwargs )).insert_before (
44+ node
45+ )
46+
47+ node .delete ()
48+
49+ return abc .RewriteResult (has_done_something = True )
50+
51+
2752class RewriteScheduleRegion (abc .RewriteRule ):
2853 CLASSES = {
2954 schedule .Auto : path .Auto ,
0 commit comments