1616# ideally we wouldn't need any runtime component
1717from ..runtime import JITFunction
1818from .._utils import find_paths_if , get_iterable_path , set_iterable_path
19+ from . import config
1920
2021from .errors import (CompilationError , CompileTimeAssertionFailure , UnsupportedLanguageConstruct )
2122
@@ -27,29 +28,9 @@ def check_identifier_legality(name, type):
2728 return name
2829
2930
30- def mangle_ty (ty ):
31- if ty .is_tuple ():
32- return 'T' + '_' .join (map (mangle_ty , ty .types )) + 'T'
33- if ty .is_ptr ():
34- return 'P' + mangle_ty (ty .element_ty )
35- if ty .is_int ():
36- SIGNED = language .dtype .SIGNEDNESS .SIGNED
37- prefix = 'i' if ty .int_signedness == SIGNED else 'u'
38- return prefix + str (ty .int_bitwidth )
39- if ty .is_floating ():
40- return str (ty )
41- if ty .is_block ():
42- elt = mangle_ty (ty .scalar )
43- shape = '_' .join (map (str , ty .shape ))
44- return f'{ elt } S{ shape } S'
45- if ty .is_void ():
46- return 'V'
47- raise TypeError (f'Unsupported type { ty } ' )
48-
49-
5031def mangle_fn (name , arg_tys , constants ):
5132 # doesn't mangle ret type, which must be a function of arg tys
52- mangled_arg_names = '_' .join ([mangle_ty ( ty ) for ty in arg_tys ])
33+ mangled_arg_names = '_' .join ([ty . mangle ( ) for ty in arg_tys ])
5334 mangled_constants = '_' .join ([f'{ i } c{ repr (constants [i ])} ' for i in sorted (constants )])
5435 mangled_constants = mangled_constants .replace ('.' , '_d_' )
5536 mangled_constants = mangled_constants .replace ("'" , '_sq_' )
@@ -71,8 +52,8 @@ def _is_constexpr(o: Any) -> bool:
7152 return o is None or isinstance (o , (constexpr , language .core .dtype ))
7253
7354
74- def _is_triton_scalar (o : Any ) -> bool :
75- return _is_triton_tensor (o ) and (not o .type .is_block () or o .type .numel = = 1 )
55+ def _is_non_scalar_tensor (o : Any ) -> bool :
56+ return _is_triton_tensor (o ) and (o .type .is_block () and o .type .numel ! = 1 )
7657
7758
7859def _is_list_like (o : Any ) -> bool :
@@ -82,7 +63,7 @@ def _is_list_like(o: Any) -> bool:
8263def _check_fn_args (node , fn , args ):
8364 if fn .noinline :
8465 for idx , arg in enumerate (args ):
85- if not _is_constexpr (arg ) and not _is_triton_scalar (arg ):
66+ if not _is_constexpr (arg ) and _is_non_scalar_tensor (arg ):
8667 raise UnsupportedLanguageConstruct (
8768 fn .src , node ,
8869 f'Function { fn .__name__ } is marked noinline, but was called with non-scalar argument { fn .arg_names [idx ]} :{ arg } '
@@ -241,26 +222,26 @@ def __init__(self, ret_types, arg_types, constants, attrs):
241222 self .constants = constants
242223 self .attrs = attrs
243224
244- def return_types_ir (self , builder : ir .builder ) :
245- ret_types = []
246- for ret_ty in self . ret_types :
247- if ret_ty is None :
225+ def flatten_ir_types (self , builder : ir .builder , types : List [ base_type ]) -> List [ ir . type ] :
226+ ir_types = []
227+ for ty in types :
228+ if ty is None :
248229 continue
249- ir_ty = ret_ty .to_ir (builder )
250- if isinstance (ir_ty , list ):
251- ret_types .extend (ir_ty )
252- else :
253- ret_types .append (ir_ty )
254- return ret_types
230+ ty ._flatten_ir_types (builder , ir_types )
231+ return ir_types
232+
233+ def return_types_ir (self , builder : ir .builder ) -> List [ir .type ]:
234+ return self .flatten_ir_types (builder , self .ret_types )
255235
256236 def serialize (self , builder : ir .builder ):
257237 # fill up IR values in template
258238 # > build function
259239 is_val = lambda path , _ : path not in self .constants and _ is not None
260240 val_paths = list (find_paths_if (self .arg_types , is_val ))
261- arg_types = [get_iterable_path (self .arg_types , path ).to_ir (builder ) for path in val_paths ]
262- ret_types = self .return_types_ir (builder )
263- return builder .get_function_ty (arg_types , ret_types )
241+ arg_types = [get_iterable_path (self .arg_types , path ) for path in val_paths ]
242+ arg_types_ir = self .flatten_ir_types (builder , arg_types )
243+ ret_types_ir = self .return_types_ir (builder )
244+ return builder .get_function_ty (arg_types_ir , ret_types_ir )
264245
265246 def deserialize (self , fn ):
266247 # create "template"
@@ -282,9 +263,12 @@ def make_template(ty):
282263 if isinstance (ty , nv_tma_desc_type ):
283264 fn .set_arg_attr (i , "tt.nv_tma_desc" , 1 )
284265 # > add IR values to the template
285- for i , path in enumerate (val_paths ):
266+ cursor = 0
267+ handles = [fn .args (i ) for i in range (fn .get_num_args ())]
268+ for path in val_paths :
286269 ty = get_iterable_path (self .arg_types , path )
287- set_iterable_path (vals , path , language .tensor (fn .args (i ), ty ))
270+ val , cursor = ty ._unflatten_ir (handles , cursor )
271+ set_iterable_path (vals , path , val )
288272 # > add constexpr values to the template
289273 constants = self .constants
290274 for path , val in constants .items ():
@@ -1218,14 +1202,16 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
12181202 generator .visit (fn .parse ())
12191203 except Exception as e :
12201204 # Wrap the error in the callee with the location of the call.
1205+ if config .front_end_debugging ():
1206+ raise
12211207 raise CompilationError (self .jit_fn .src , self .cur_node , None ) from e
12221208
12231209 callee_ret_type = generator .ret_type
12241210 self .function_ret_types [fn_name ] = callee_ret_type
12251211 else :
12261212 callee_ret_type = self .function_ret_types [fn_name ]
12271213 symbol = self .module .get_function (fn_name )
1228- args_val = [ arg . handle for arg in args_val ]
1214+ args_val = flatten_values_to_ir ( args_val )
12291215 call_op = self .builder .call (symbol , args_val )
12301216 if callee_ret_type == language .void :
12311217 return None
@@ -1256,6 +1242,8 @@ def visit_Call(self, node):
12561242 ret = language .tuple (ret )
12571243 return ret
12581244 except Exception as e :
1245+ if config .front_end_debugging ():
1246+ raise
12591247 # Normally when we raise a CompilationError, we raise it as
12601248 # `from None`, because the original fileline from the exception
12611249 # is not relevant (and often points into code_generator.py
@@ -1335,6 +1323,8 @@ def visit(self, node):
13351323 except CompilationError :
13361324 raise
13371325 except Exception as e :
1326+ if config .front_end_debugging ():
1327+ raise
13381328 # Wrap the error in a CompilationError which contains the source
13391329 # of the @jit function.
13401330 raise CompilationError (self .jit_fn .src , self .cur_node , repr (e )) from None
0 commit comments