Skip to content

Commit dc9a5bb

Browse files
weinbe58david-pl
andauthored
Fix CallGraphPass issues in Kirin 0.20 (#586)
In this PR I port the `CallGraphPass` originally in Kirin 0.17 to the passes here in bloqade-circuit. I have also updated them to use the new Kirin APIs so they are not compatible with Kirin 0.20. Unfortunately the tests are currently breaking because of some import issues so the CI will not be able to pass for the tests until those are fixed. Co-authored-by: David Plankensteiner <[email protected]>
1 parent 63bfdff commit dc9a5bb

File tree

3 files changed

+125
-63
lines changed

3 files changed

+125
-63
lines changed
Lines changed: 4 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from itertools import chain
2-
from dataclasses import field, dataclass
32

4-
from kirin import ir, passes, rewrite
3+
from kirin import ir, rewrite
54
from kirin.dialects import py, func
65
from kirin.rewrite.abc import RewriteRule, RewriteResult
7-
from kirin.passes.callgraph import CallGraphPass, ReplaceMethods
86
from kirin.analysis.callgraph import CallGraph
97

108
from bloqade.native import kernel, broadcast
119
from bloqade.squin.gate import stmts, dialect as gate_dialect
10+
from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph
1211

1312

1413
class GateRule(RewriteRule):
@@ -46,63 +45,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
4645
return RewriteResult(has_done_something=True)
4746

4847

49-
@dataclass
50-
class UpdateDialectsOnCallGraph(passes.Pass):
51-
"""Update All dialects on the call graph to a new set of dialects given to this pass.
52-
53-
Usage:
54-
pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects)
55-
pass_(some_method)
56-
57-
Note: This pass does not update the dialects of the input method, but copies
58-
all other methods invoked within it before updating their dialects.
59-
60-
"""
61-
62-
fold_pass: passes.Fold = field(init=False)
63-
64-
def __post_init__(self):
65-
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
66-
67-
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
68-
mt_map = {}
69-
70-
cg = CallGraph(mt)
71-
72-
all_methods = set(sum(map(tuple, cg.defs.values()), ()))
73-
for original_mt in all_methods:
74-
if original_mt is mt:
75-
new_mt = original_mt
76-
else:
77-
new_mt = original_mt.similar(self.dialects)
78-
mt_map[original_mt] = new_mt
79-
80-
result = RewriteResult()
81-
82-
for _, new_mt in mt_map.items():
83-
result = (
84-
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
85-
)
86-
self.fold_pass(new_mt)
87-
88-
return result
89-
90-
91-
@dataclass
92-
class SquinToNativePass(passes.Pass):
93-
94-
call_graph_pass: CallGraphPass = field(init=False)
95-
96-
def __post_init__(self):
97-
rule = rewrite.Walk(GateRule())
98-
self.call_graph_pass = CallGraphPass(
99-
self.dialects, rule, no_raise=self.no_raise
100-
)
101-
102-
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
103-
return self.call_graph_pass.unsafe_run(mt)
104-
105-
10648
class SquinToNative:
10749
"""A Target that converts Squin gates to native gates."""
10850

@@ -126,11 +68,10 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
12668

12769
out = mt.similar(new_dialects)
12870
UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out)
129-
SquinToNativePass(new_dialects, no_raise=no_raise)(out)
71+
CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out)
13072
# verify all kernels in the callgraph
13173
new_callgraph = CallGraph(out)
132-
all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers)
133-
for ker in all_kernels:
74+
for ker in new_callgraph.edges.keys():
13475
ker.verify()
13576

13677
return out
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
1+
from .callgraph import (
2+
CallGraphPass as CallGraphPass,
3+
ReplaceMethods as ReplaceMethods,
4+
UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph,
5+
)
16
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
27
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from dataclasses import field, dataclass
2+
3+
from kirin import ir, passes, rewrite
4+
from kirin.analysis import CallGraph
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
from kirin.dialects.func.stmts import Invoke
7+
8+
9+
@dataclass
10+
class ReplaceMethods(RewriteRule):
11+
new_symbols: dict[ir.Method, ir.Method]
12+
13+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
14+
if (
15+
not isinstance(node, Invoke)
16+
or (new_callee := self.new_symbols.get(node.callee)) is None
17+
):
18+
return RewriteResult()
19+
20+
node.replace_by(
21+
Invoke(
22+
inputs=node.inputs,
23+
callee=new_callee,
24+
purity=node.purity,
25+
)
26+
)
27+
28+
return RewriteResult(has_done_something=True)
29+
30+
31+
@dataclass
32+
class UpdateDialectsOnCallGraph(passes.Pass):
33+
"""Update All dialects on the call graph to a new set of dialects given to this pass.
34+
35+
Usage:
36+
pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects)
37+
pass_(some_method)
38+
39+
Note: This pass does not update the dialects of the input method, but copies
40+
all other methods invoked within it before updating their dialects.
41+
42+
"""
43+
44+
fold_pass: passes.Fold = field(init=False)
45+
46+
def __post_init__(self):
47+
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
48+
49+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
50+
mt_map = {}
51+
52+
cg = CallGraph(mt)
53+
54+
all_methods = set(sum(map(tuple, cg.defs.values()), ()))
55+
for original_mt in all_methods:
56+
if original_mt is mt:
57+
new_mt = original_mt
58+
else:
59+
new_mt = original_mt.similar(self.dialects)
60+
mt_map[original_mt] = new_mt
61+
62+
result = RewriteResult()
63+
64+
for _, new_mt in mt_map.items():
65+
result = (
66+
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
67+
)
68+
self.fold_pass(new_mt)
69+
70+
return result
71+
72+
73+
@dataclass
74+
class CallGraphPass(passes.Pass):
75+
"""Copy all functions in the call graph and apply a rule to each of them.
76+
77+
78+
Usage:
79+
rule = Walk(SomeRewriteRule())
80+
pass_ = CallGraphPass(rule=rule, dialects=...)
81+
pass_(some_method)
82+
83+
Note: This pass modifies the input method in place, but copies
84+
all methods invoked within it before applying the rule to them.
85+
86+
"""
87+
88+
rule: RewriteRule
89+
"""The rule to apply to each function in the call graph."""
90+
91+
fold_pass: passes.Fold = field(init=False)
92+
93+
def __post_init__(self):
94+
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
95+
96+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
97+
result = RewriteResult()
98+
mt_map = {}
99+
100+
cg = CallGraph(mt)
101+
102+
all_methods = set(cg.edges.keys())
103+
for original_mt in all_methods:
104+
if original_mt is mt:
105+
new_mt = original_mt
106+
else:
107+
new_mt = original_mt.similar()
108+
result = self.rule.rewrite(new_mt.code).join(result)
109+
mt_map[original_mt] = new_mt
110+
111+
if result.has_done_something:
112+
for _, new_mt in mt_map.items():
113+
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code)
114+
self.fold_pass(new_mt)
115+
116+
return result

0 commit comments

Comments
 (0)