Skip to content

Commit a6a8f03

Browse files
authored
RFC: support overwriting existing method (#421)
this PR add support of overwriting existing method by directly updating the existing method object in (current Python not Kirin) scope if the name matches. We limit the compiler not to update existing method because this may cause unexpected behaviour because we won't actually update existing compiled code. However, after thinking a bit more, I think this is still ok because we can keep the same method object but update its body. The caveat is we need to retrigger the whole compilation of call the caller of this method, which is probably still ok? I didn't implement the [entire world age mechanism in Julia](https://d1wqtxts1xzle7.cloudfront.net/73563045/3428275-libre.pdf?1635141917=&response-content-disposition=inline%3B+filename%3DWorld_age_in_Julia_optimizing_method_dis.pdf&Expires=1749052521&Signature=OpvF9l3YO6ZRxHylvBFSMIH0O1Fl6AO2nyg1EZuD-5SWpZZw8E50yy7mgLcGMKhCwTWm6414MhLU9BwolYYnId-ichrBKTkxiDH3MPp050NFtXtmuGQ0Hxnc2wgeDgvKJPiNfVLydlPQLu7daEWw9uB4J2rAzb42m37YBOkfVvdjwvBHg8ysLncdDs-rc15mXKCmvsAYTFl7zRdaUxxhgphNYO~SOdNLB2QWxPpA8SVu3p-HGGitEzyjhtBJi9aQzH0xgBStX6~l2DIyMLEwlhYwjByrw0~moMhKWmOpRELyTfQdYeOxXD39YcztERPOFCE9dTHgzFv0fX-tMaDfkA__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA) but technically every time you define a new method it should trigger an eval -ish process that recompiles the related methods and thus creates a new world age. So I'd like to see if there are volunteers willing to take a second look and play with this PR a bit before moving forward. ## Motivation currently when supporting interactive environments, such as a notebook, it is quite cumbersome to use notebook as a dev environment like how one codes Python natively because we cannot update the code of an existing `Method`.
1 parent a49a961 commit a6a8f03

File tree

9 files changed

+143
-50
lines changed

9 files changed

+143
-50
lines changed

src/kirin/analysis/callgraph.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,42 +18,49 @@ class CallGraph(Printable):
1818
[`.print()`][kirin.print.printable.Printable.print] method.
1919
"""
2020

21-
"""Mapping from symbol names to methods."""
21+
edges: dict[ir.Method, set[ir.Method]] = field(default_factory=dict)
22+
"""Mapping from symbol names to edges (caller -> callee)."""
2223
backedges: dict[ir.Method, set[ir.Method]] = field(default_factory=dict)
23-
"""Mapping from symbol names to backedges."""
24+
"""Mapping from symbol names to backedges (callee -> caller)."""
2425

2526
def __init__(self, mt: ir.Method):
2627
self.defs = {}
28+
self.edges = {}
2729
self.backedges = {}
2830
self.__build(mt)
2931

3032
def __build(self, mt: ir.Method):
3133
for stmt in mt.callable_region.walk():
3234
if isinstance(stmt, func.Invoke):
33-
backedges = self.backedges.setdefault(stmt.callee, set())
34-
backedges.add(mt)
35+
edges = self.edges.setdefault(stmt.callee, set())
36+
edges.add(mt)
3537
self.__build(stmt.callee)
3638

39+
for caller in self.edges:
40+
for callee in self.edges[caller]:
41+
backedges = self.backedges.setdefault(callee, set())
42+
backedges.add(caller)
43+
3744
def get_neighbors(self, node: ir.Method) -> Iterable[ir.Method]:
3845
"""Get the neighbors of a node in the call graph."""
39-
return self.backedges.get(node, ())
46+
return self.edges.get(node, ())
4047

4148
def get_edges(self) -> Iterable[tuple[ir.Method, ir.Method]]:
4249
"""Get the edges of the call graph."""
43-
for node, neighbors in self.backedges.items():
50+
for node, neighbors in self.edges.items():
4451
for neighbor in neighbors:
4552
yield node, neighbor
4653

4754
def get_nodes(self) -> Iterable[ir.Method]:
4855
"""Get the nodes of the call graph."""
49-
return self.backedges.keys()
56+
return self.edges.keys()
5057

5158
def print_impl(self, printer: Printer) -> None:
52-
for idx, (caller, callee) in enumerate(self.backedges.items()):
59+
for idx, (caller, callee) in enumerate(self.edges.items()):
5360
printer.plain_print(caller)
5461
printer.plain_print(" -> ")
5562
printer.print_seq(
5663
callee, delim=", ", prefix="[", suffix="]", emit=printer.plain_print
5764
)
58-
if idx < len(self.backedges) - 1:
65+
if idx < len(self.edges) - 1:
5966
printer.print_newline()

src/kirin/dialects/func/stmts.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ def align_input_args(
3030
return tuple(inputs)
3131

3232

33+
class InvokeCall(ir.StaticCall["Invoke"]):
34+
35+
@classmethod
36+
def get_callee(cls, stmt: Invoke) -> ir.Method:
37+
return stmt.callee
38+
39+
3340
@statement(dialect=dialect)
3441
class Function(ir.Statement):
3542
name = "func"
@@ -270,7 +277,7 @@ def check_type(self) -> None:
270277
@statement(dialect=dialect)
271278
class Invoke(ir.Statement):
272279
name = "invoke"
273-
traits = frozenset({ir.MaybePure()})
280+
traits = frozenset({ir.MaybePure(), InvokeCall()})
274281
callee: ir.Method = info.attribute()
275282
inputs: tuple[ir.SSAValue, ...] = info.argument()
276283
result: ir.ResultValue = info.result()

src/kirin/ir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
HasParent as HasParent,
3131
MaybePure as MaybePure,
3232
StmtTrait as StmtTrait,
33+
StaticCall as StaticCall,
3334
RegionGraph as RegionGraph,
3435
SymbolTable as SymbolTable,
3536
ConstantLike as ConstantLike,

src/kirin/ir/group.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -200,46 +200,76 @@ def wrapper(py_func: Callable) -> Method:
200200
raise ValueError("Cannot compile lambda functions")
201201

202202
lineno_offset, file = 0, ""
203+
mt = None
203204
if frame and frame.f_back is not None:
204205
call_site_frame = frame.f_back
205206
if py_func.__name__ in call_site_frame.f_locals:
206-
raise CompilerError(
207-
f"overwriting function definition of `{py_func.__name__}`"
208-
)
207+
mt = call_site_frame.f_locals[py_func.__name__]
208+
if not isinstance(mt, Method):
209+
raise CompilerError(
210+
f"`{py_func.__name__}` is already defined in the current scope and is not a Method."
211+
)
209212

210213
lineno_offset = call_site_frame.f_lineno - 1
211214
file = call_site_frame.f_code.co_filename
212215

213216
code = self.lowering.python_function(py_func, lineno_offset=lineno_offset)
214217
arg_names = ["#self#"] + inspect.getfullargspec(py_func).args
215-
mt = Method(
216-
dialects=self,
217-
code=code,
218-
nargs=len(arg_names),
219-
mod=inspect.getmodule(py_func),
220-
py_func=py_func,
221-
sym_name=py_func.__name__,
222-
arg_names=arg_names,
223-
file=file,
224-
lineno_begin=lineno_offset,
225-
)
218+
219+
if mt:
220+
mt.mod = inspect.getmodule(py_func)
221+
mt.dialects = self
222+
mt.code = code
223+
mt.py_func = py_func
224+
mt.nargs = len(arg_names)
225+
mt.arg_names = arg_names
226+
mt.sym_name = py_func.__name__
227+
mt.file = file
228+
mt.lineno_begin = lineno_offset
229+
mt.run_passes = self.run_pass
230+
mt.update_backedges() # update the callee
231+
self.recompile_callers(mt)
232+
else:
233+
mt = Method(
234+
dialects=self,
235+
code=code,
236+
nargs=len(arg_names),
237+
mod=inspect.getmodule(py_func),
238+
py_func=py_func,
239+
sym_name=py_func.__name__,
240+
arg_names=arg_names,
241+
file=file,
242+
lineno_begin=lineno_offset,
243+
)
244+
226245
if doc := inspect.getdoc(py_func):
227246
mt.__doc__ = doc
228247

229-
if self.run_pass is not None:
230-
try:
231-
self.run_pass(mt, *args, **options)
232-
except ValidationError as e:
233-
e.attach(mt)
234-
raise e
235-
248+
def run_pass(mt: Method) -> None:
249+
if self.run_pass is not None:
250+
try:
251+
self.run_pass(mt, *args, **options)
252+
except ValidationError as e:
253+
e.attach(mt)
254+
raise e
255+
256+
mt.run_passes = run_pass
257+
run_pass(mt)
236258
self.update_symbol_table(mt)
237259
return mt
238260

239261
if py_func is not None:
240262
return wrapper(py_func)
241263
return wrapper
242264

265+
def recompile_callers(self, method: Method) -> None:
266+
for caller in method.backedges:
267+
if caller.run_passes:
268+
caller.run_passes(caller)
269+
# propagate the changes to all callers
270+
caller.dialects.recompile_callers(caller)
271+
return
272+
243273
def update_symbol_table(self, method: Method) -> None:
244274
trait = method.code.get_trait(SymbolTable)
245275
if trait is None:

src/kirin/ir/method.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from kirin.print.printable import Printable
1111

1212
from .traits import (
13+
StaticCall,
1314
HasSignature,
1415
SymbolOpInterface,
1516
EntryPointInterface,
@@ -52,6 +53,9 @@ class Method(Printable, typing.Generic[Param, RetType]):
5253
inferred: bool = False
5354
"""if typeinfer has been run on this method
5455
"""
56+
backedges: set[Method] = field(init=False, repr=False)
57+
"""Cache for the backedges. (who calls this method)"""
58+
run_passes: typing.Callable[[Method], None] | None = field(init=False, repr=False)
5559

5660
def __init__(
5761
self,
@@ -96,6 +100,9 @@ def __init__(
96100
self.file = file
97101
self.lineno_begin = lineno_begin
98102
self.inferred = inferred
103+
self.backedges = set()
104+
self.update_backedges()
105+
self.run_passes = None
99106

100107
def __hash__(self) -> int:
101108
return id(self)
@@ -185,3 +192,13 @@ def verify_type(self) -> None:
185192
except ValidationError as e:
186193
e.attach(self)
187194
raise e
195+
196+
def update_backedges(self):
197+
"""Update the backedges of callee methods. (if they are static calls)"""
198+
for stmt in self.code.walk():
199+
trait = stmt.get_trait(StaticCall)
200+
if not trait:
201+
continue
202+
203+
callee = trait.get_callee(stmt)
204+
callee.backedges.add(self)

src/kirin/ir/traits/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
EntryPointInterface as EntryPointInterface,
3232
)
3333
from .callable import (
34+
StaticCall as StaticCall,
3435
HasSignature as HasSignature,
3536
CallableStmtInterface as CallableStmtInterface,
3637
)

src/kirin/ir/traits/callable.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from kirin.ir.traits.abc import StmtTrait
88

99
if TYPE_CHECKING:
10-
from kirin.ir import Region, Statement
10+
from kirin.ir import Method, Region, Statement
1111
from kirin.dialects.func.attrs import Signature
1212

1313
StmtType = TypeVar("StmtType", bound="Statement")
@@ -63,3 +63,12 @@ def verify(self, node: "Statement"):
6363
signature = self.get_signature(node)
6464
if not isinstance(signature, Signature):
6565
raise ValueError(f"{signature} is not a Signature attribute")
66+
67+
68+
class StaticCall(StmtTrait, ABC, Generic[StmtType]):
69+
70+
@classmethod
71+
@abstractmethod
72+
def get_callee(cls, stmt: StmtType) -> "Method":
73+
"""Returns the callee of the static call statement."""
74+
...

test/ir/test_duplicated.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

test/ir/test_group.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from kirin.ir import DialectGroup
2+
from kirin.prelude import basic
3+
from kirin.analysis import const
24
from kirin.dialects import cf, func
35
from kirin.dialects.py import base
46

@@ -45,3 +47,40 @@ def test_discard():
4547
assert "DialectGroup(" in target_b_repr
4648
assert base.dialect.name in target_b_repr
4749
assert cf.dialect.name in target_b_repr
50+
51+
52+
def test_overwrite():
53+
@basic
54+
def foo(x): # type: ignore
55+
return x * 2
56+
57+
@basic
58+
def main(x):
59+
return x + foo(x)
60+
61+
assert main(2) == 6
62+
63+
@basic
64+
def foo(x): # noqa: F811
65+
return x * 3
66+
67+
assert main(2) == 8
68+
69+
70+
def test_recompile():
71+
@basic
72+
def foo(x): # type: ignore
73+
return x * 2
74+
75+
@basic(fold=True)
76+
def main(x):
77+
return x + foo(x)
78+
79+
ret = main.callable_region.blocks[0].stmts.at(0).results[0]
80+
assert isinstance(ret.hints.get("const"), const.Unknown)
81+
82+
@basic
83+
def foo(x): # noqa: F811
84+
return 3
85+
86+
assert isinstance(ret.hints.get("const"), const.Value)

0 commit comments

Comments
 (0)