@@ -28,7 +28,7 @@ def check_identifier_legality(name, type):
28
28
return name
29
29
30
30
31
- def mangle_fn (name , arg_tys , constants ):
31
+ def mangle_fn (name , arg_tys , constants , caller_context ):
32
32
# doesn't mangle ret type, which must be a function of arg tys
33
33
mangled_arg_names = '_' .join ([ty .mangle () for ty in arg_tys ])
34
34
mangled_constants = '_' .join ([f'{ i } c{ repr (constants [i ])} ' for i in sorted (constants )])
@@ -37,6 +37,8 @@ def mangle_fn(name, arg_tys, constants):
37
37
# [ and ] are not allowed in LLVM identifiers
38
38
mangled_constants = mangled_constants .replace ('[' , '_' ).replace (']' , '_' )
39
39
ret = f'{ name } __{ mangled_arg_names } __{ mangled_constants } '
40
+ if caller_context is not None :
41
+ ret += caller_context .mangle ()
40
42
return ret
41
43
42
44
@@ -293,7 +295,7 @@ class CodeGenerator(ast.NodeVisitor):
293
295
294
296
def __init__ (self , context , prototype , gscope , function_name , jit_fn : JITFunction , options , codegen_fns , module_map ,
295
297
module = None , is_kernel = False , function_types : Optional [Dict ] = None , noinline = False ,
296
- file_name : Optional [str ] = None , begin_line = 0 ):
298
+ caller_context = None , file_name : Optional [str ] = None , begin_line = 0 ):
297
299
self .context = context
298
300
if jit_fn .is_gluon ():
299
301
from triton .experimental .gluon .language ._semantic import GluonSemantic
@@ -339,6 +341,7 @@ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunctio
339
341
self .is_kernel = is_kernel
340
342
self .cur_node = None
341
343
self .noinline = noinline
344
+ self .caller_context = caller_context
342
345
self .scf_stack = []
343
346
self .ret_type = None
344
347
# SSA-construction
@@ -570,6 +573,8 @@ def visit_FunctionDef(self, node):
570
573
self .module .push_back (self .fn )
571
574
entry = self .fn .add_entry_block ()
572
575
arg_values = self .prototype .deserialize (self .fn )
576
+ if self .caller_context is not None :
577
+ self .caller_context .initialize_callee (self .fn , self .builder )
573
578
# bind arguments to symbols
574
579
for arg_name , arg_value in zip (arg_names , arg_values ):
575
580
self .set_value (arg_name , arg_value )
@@ -1190,7 +1195,7 @@ def visit_Assert(self, node) -> Any:
1190
1195
msg = self .visit (node .msg ) if node .msg is not None else ""
1191
1196
return language .core .device_assert (test , msg , _semantic = self .semantic )
1192
1197
1193
- def call_JitFunction (self , fn : JITFunction , args , kwargs ):
1198
+ def call_JitFunction (self , fn : JITFunction , args , kwargs , caller_context = None ):
1194
1199
args = inspect .getcallargs (fn .fn , * args , ** kwargs )
1195
1200
args = [args [name ] for name in fn .arg_names ]
1196
1201
for i , arg in enumerate (args ):
@@ -1201,7 +1206,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
1201
1206
args_path = find_paths_if (args , lambda _ , x : not _is_constexpr (x ))
1202
1207
args_val = [get_iterable_path (args , path ) for path in args_path ]
1203
1208
# mangle
1204
- fn_name = mangle_fn (get_full_name (fn ), [arg .type for arg in args_val ], args_cst )
1209
+ caller_context = caller_context or self .caller_context
1210
+ fn_name = mangle_fn (get_full_name (fn ), [arg .type for arg in args_val ], args_cst , caller_context )
1205
1211
# generate function def if necessary
1206
1212
if not self .module .has_function (fn_name ):
1207
1213
# If the callee is not set, we use the same debug setting as the caller
@@ -1216,7 +1222,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
1216
1222
function_name = fn_name , function_types = self .function_ret_types ,
1217
1223
noinline = fn .noinline , file_name = file_name , begin_line = begin_line ,
1218
1224
options = self .builder .options , codegen_fns = self .builder .codegen_fns ,
1219
- module_map = self .builder .module_map )
1225
+ module_map = self .builder .module_map , caller_context = caller_context )
1220
1226
try :
1221
1227
generator .visit (fn .parse ())
1222
1228
except Exception as e :
0 commit comments