1515from .errors import (CompilationError , CompileTimeAssertionFailure , UnsupportedLanguageConstruct )
1616from types import ModuleType
1717from triton ._utils import list_list_flatten , list_list_unflatten
18- from functools import reduce
19- from .._utils import find_paths_if
18+ from .._utils import find_paths_if , get_iterable_path , set_iterable_path
2019
2120
2221def mangle_ty (ty ):
@@ -195,13 +194,6 @@ def visit_Call(self, node: ast.Call) -> bool:
195194
196195class ASTFunction :
197196
198- def get_path (self , x , path ):
199- return reduce (lambda a , idx : a [idx ], path , x )
200-
201- def set_path (self , x , path , val ):
202- prev = x if len (path ) == 1 else self .get_path (x , path [:- 1 ])
203- prev [path [- 1 ]] = val
204-
205197 def __init__ (self , ret_types , arg_types , constexprs , constants , attrs ):
206198 self .ret_types = ret_types
207199 self .arg_types = arg_types
@@ -213,8 +205,8 @@ def serialize(self, builder: ir.builder):
213205 # fill up IR values in template
214206 # > build function
215207 is_val = lambda path , _ : path not in self .constexprs and _ is not None
216- val_paths = list (find_paths_if (self .arg_types , is_val ). keys () )
217- arg_types = [self . get_path (self .arg_types , path ).to_ir (builder ) for path in val_paths ]
208+ val_paths = list (find_paths_if (self .arg_types , is_val ))
209+ arg_types = [get_iterable_path (self .arg_types , path ).to_ir (builder ) for path in val_paths ]
218210 ret_types = [ret_type .to_ir (builder ) for ret_type in self .ret_types ]
219211 return builder .get_function_ty (arg_types , ret_types )
220212
@@ -227,24 +219,24 @@ def make_template(val):
227219
228220 vals = make_template (self .arg_types )
229221 is_val = lambda path , _ : path not in self .constexprs and _ is not None
230- val_paths = list (find_paths_if (self .arg_types , is_val ). keys () )
222+ val_paths = list (find_paths_if (self .arg_types , is_val ))
231223 # > set attributes
232224 for attr_path , attr_specs in self .attrs .items ():
233225 for attr_name , attr_val in attr_specs :
234226 if attr_path in val_paths :
235227 fn .set_arg_attr (val_paths .index (attr_path ), attr_name , attr_val )
236228 for i , path in enumerate (val_paths ):
237- ty = self . get_path (self .arg_types , path )
229+ ty = get_iterable_path (self .arg_types , path )
238230 if isinstance (ty , nv_tma_desc_type ):
239231 fn .set_arg_attr (i , "tt.nv_tma_desc" , 1 )
240232 # > add IR values to the template
241233 for i , path in enumerate (val_paths ):
242- ty = self . get_path (self .arg_types , path )
243- self . set_path (vals , path , language .tensor (fn .args (i ), ty ))
234+ ty = get_iterable_path (self .arg_types , path )
235+ set_iterable_path (vals , path , language .tensor (fn .args (i ), ty ))
244236 # > add constexpr values to the template
245237 constants = self .constants | self .constexprs
246238 for path , val in constants .items ():
247- self . set_path (vals , path , language .constexpr (val ))
239+ set_iterable_path (vals , path , language .constexpr (val ))
248240 return vals
249241
250242
@@ -1139,7 +1131,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
11391131 if isinstance (arg , (language .dtype , float , int , bool )):
11401132 args [i ] = language .core .constexpr (arg )
11411133 args_cst = find_paths_if (args , lambda _ , x : _is_constexpr (x ))
1142- args_val = find_paths_if (args , lambda _ , x : not _is_constexpr (x )).values ()
1134+ args_cst = {path : get_iterable_path (args , path ) for path in args_cst }
1135+ args_path = find_paths_if (args , lambda _ , x : not _is_constexpr (x ))
1136+ args_val = [get_iterable_path (args , path ) for path in args_path ]
11431137 # mangle
11441138 fn_name = mangle_fn (fn .__name__ , [arg .type for arg in args_val ], args_cst )
11451139 # generate function def if necessary
0 commit comments