Skip to content

Commit 974c7de

Browse files
authored
add callgraph (#103)
this add an analysis constructs callgraph
1 parent 5e5d2d7 commit 974c7de

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

src/kirin/analysis/callgraph.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from dataclasses import dataclass, field
2+
from typing import Iterable
3+
4+
from kirin import ir
5+
from kirin.dialects import func
6+
from kirin.print import Printable
7+
from kirin.print.printer import Printer
8+
9+
10+
@dataclass
11+
class CallGraph(Printable):
12+
defs: dict[str, ir.Method] = field(default_factory=dict)
13+
backedges: dict[str, set[str]] = field(default_factory=dict)
14+
15+
def __init__(self, mt: ir.Method):
16+
self.defs = {}
17+
self.backedges = {}
18+
self.__build(mt)
19+
20+
def __build(self, mt: ir.Method):
21+
self.defs[mt.sym_name] = mt
22+
for stmt in mt.callable_region.walk():
23+
if isinstance(stmt, func.Invoke):
24+
backedges = self.backedges.setdefault(stmt.callee.sym_name, set())
25+
backedges.add(mt.sym_name)
26+
self.__build(stmt.callee)
27+
28+
def get_neighbors(self, node: str) -> Iterable[str]:
29+
return self.backedges.get(node, ())
30+
31+
def get_edges(self) -> Iterable[tuple[str, str]]:
32+
for node, neighbors in self.backedges.items():
33+
for neighbor in neighbors:
34+
yield node, neighbor
35+
36+
def get_nodes(self) -> Iterable[str]:
37+
return self.defs.keys()
38+
39+
def print_impl(self, printer: Printer) -> None:
40+
for idx, (caller, callee) in enumerate(self.backedges.items()):
41+
printer.plain_print(caller)
42+
printer.plain_print(" -> ")
43+
printer.print_seq(
44+
callee, delim=", ", prefix="[", suffix="]", emit=printer.plain_print
45+
)
46+
if idx < len(self.backedges) - 1:
47+
printer.print_newline()

test/analysis/test_callgraph.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from kirin.analysis.callgraph import CallGraph
2+
from kirin.prelude import basic_no_opt
3+
4+
5+
@basic_no_opt
6+
def abc(a, b):
7+
return a + b
8+
9+
10+
@basic_no_opt
11+
def bcd(a, b):
12+
return a - b
13+
14+
15+
@basic_no_opt
16+
def cde(a, b):
17+
return abc(a, b) + bcd(a, b)
18+
19+
20+
@basic_no_opt
21+
def defg(a, b):
22+
return cde(a, b) + abc(a, b)
23+
24+
25+
@basic_no_opt
26+
def efg(a, b):
27+
return defg(a, b) + bcd(a, b)
28+
29+
30+
def test_callgraph():
31+
graph = CallGraph(efg)
32+
graph.print()
33+
assert "cde" in graph.get_neighbors("abc")
34+
assert "defg" in graph.get_neighbors("abc")
35+
assert "cde" in graph.get_neighbors("abc")
36+
assert "defg" in graph.get_neighbors("abc")

0 commit comments

Comments
 (0)