@@ -200,46 +200,76 @@ def wrapper(py_func: Callable) -> Method:
200200 raise ValueError ("Cannot compile lambda functions" )
201201
202202 lineno_offset , file = 0 , ""
203+ mt = None
203204 if frame and frame .f_back is not None :
204205 call_site_frame = frame .f_back
205206 if py_func .__name__ in call_site_frame .f_locals :
206- raise CompilerError (
207- f"overwriting function definition of `{ py_func .__name__ } `"
208- )
207+ mt = call_site_frame .f_locals [py_func .__name__ ]
208+ if not isinstance (mt , Method ):
209+ raise CompilerError (
210+ f"`{ py_func .__name__ } ` is already defined in the current scope and is not a Method."
211+ )
209212
210213 lineno_offset = call_site_frame .f_lineno - 1
211214 file = call_site_frame .f_code .co_filename
212215
213216 code = self .lowering .python_function (py_func , lineno_offset = lineno_offset )
214217 arg_names = ["#self#" ] + inspect .getfullargspec (py_func ).args
215- mt = Method (
216- dialects = self ,
217- code = code ,
218- nargs = len (arg_names ),
219- mod = inspect .getmodule (py_func ),
220- py_func = py_func ,
221- sym_name = py_func .__name__ ,
222- arg_names = arg_names ,
223- file = file ,
224- lineno_begin = lineno_offset ,
225- )
218+
219+ if mt :
220+ mt .mod = inspect .getmodule (py_func )
221+ mt .dialects = self
222+ mt .code = code
223+ mt .py_func = py_func
224+ mt .nargs = len (arg_names )
225+ mt .arg_names = arg_names
226+ mt .sym_name = py_func .__name__
227+ mt .file = file
228+ mt .lineno_begin = lineno_offset
229+ mt .run_passes = self .run_pass
230+ mt .update_backedges () # update the callee
231+ self .recompile_callers (mt )
232+ else :
233+ mt = Method (
234+ dialects = self ,
235+ code = code ,
236+ nargs = len (arg_names ),
237+ mod = inspect .getmodule (py_func ),
238+ py_func = py_func ,
239+ sym_name = py_func .__name__ ,
240+ arg_names = arg_names ,
241+ file = file ,
242+ lineno_begin = lineno_offset ,
243+ )
244+
226245 if doc := inspect .getdoc (py_func ):
227246 mt .__doc__ = doc
228247
229- if self .run_pass is not None :
230- try :
231- self .run_pass (mt , * args , ** options )
232- except ValidationError as e :
233- e .attach (mt )
234- raise e
235-
248+ def run_pass (mt : Method ) -> None :
249+ if self .run_pass is not None :
250+ try :
251+ self .run_pass (mt , * args , ** options )
252+ except ValidationError as e :
253+ e .attach (mt )
254+ raise e
255+
256+ mt .run_passes = run_pass
257+ run_pass (mt )
236258 self .update_symbol_table (mt )
237259 return mt
238260
239261 if py_func is not None :
240262 return wrapper (py_func )
241263 return wrapper
242264
265+ def recompile_callers (self , method : Method ) -> None :
266+ for caller in method .backedges :
267+ if caller .run_passes :
268+ caller .run_passes (caller )
269+ # propagate the changes to all callers
270+ caller .dialects .recompile_callers (caller )
271+ return
272+
243273 def update_symbol_table (self , method : Method ) -> None :
244274 trait = method .code .get_trait (SymbolTable )
245275 if trait is None :
0 commit comments