Skip to content

Commit acc25d9

Browse files
authored
[FRONTEND] do not hold to references unnecessarily (#5402)
1 parent 3cb3e69 commit acc25d9

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

python/triton/_utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]:
2222
return ret
2323

2424

25+
def get_iterable_path(iterable, path):
    """Return the element of a nested container addressed by ``path``.

    ``path`` is a sequence of indices/keys applied one after another,
    starting from ``iterable``: ``path == (i, j)`` yields
    ``iterable[i][j]``.  An empty ``path`` returns ``iterable`` itself.

    Any ``IndexError``/``KeyError``/``TypeError`` raised by an intermediate
    ``__getitem__`` propagates unchanged.
    """
    # Imported locally, as in the original definition.
    from functools import reduce

    # Fold ``a[idx]`` over the path, seeded with the root container.
    return reduce(lambda a, idx: a[idx], path, iterable)
28+
29+
30+
def set_iterable_path(iterable, path, val):
    """Assign ``val`` at the location addressed by ``path`` inside ``iterable``.

    ``path`` is a non-empty sequence of indices/keys; every element except
    the last selects the parent container, and the last one is the slot that
    gets overwritten, so ``path == (i, j)`` performs ``iterable[i][j] = val``.
    The parent container must support item assignment.

    An empty ``path`` raises ``IndexError`` (there is no final index to
    assign to).  Errors raised by intermediate lookups propagate unchanged.
    """
    # Walk down to the parent of the target slot.  With a single-element
    # path the prefix is empty and the parent is the root itself, so no
    # special case is required.
    parent = iterable
    for idx in path[:-1]:
        parent = parent[idx]
    parent[path[-1]] = val
33+
34+
2535
def find_paths_if(iterable, pred):
2636
from .language import core
2737
is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
@@ -34,17 +44,17 @@ def _impl(current, path):
3444
_impl(item, path + (idx, ))
3545
elif pred(path, current):
3646
if len(path) == 1:
37-
ret[(path[0], )] = current
47+
ret[(path[0], )] = None
3848
else:
39-
ret[tuple(path)] = current
49+
ret[tuple(path)] = None
4050

4151
if is_iterable(iterable):
4252
_impl(iterable, [])
4353
elif pred(list(), iterable):
44-
ret = {tuple(): iterable}
54+
ret = {tuple(): None}
4555
else:
4656
ret = dict()
47-
return ret
57+
return list(ret.keys())
4858

4959

5060
def parse_list_string(s):

python/triton/compiler/code_generator.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
1616
from types import ModuleType
1717
from triton._utils import list_list_flatten, list_list_unflatten
18-
from functools import reduce
19-
from .._utils import find_paths_if
18+
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
2019

2120

2221
def mangle_ty(ty):
@@ -195,13 +194,6 @@ def visit_Call(self, node: ast.Call) -> bool:
195194

196195
class ASTFunction:
197196

198-
def get_path(self, x, path):
199-
return reduce(lambda a, idx: a[idx], path, x)
200-
201-
def set_path(self, x, path, val):
202-
prev = x if len(path) == 1 else self.get_path(x, path[:-1])
203-
prev[path[-1]] = val
204-
205197
def __init__(self, ret_types, arg_types, constexprs, constants, attrs):
206198
self.ret_types = ret_types
207199
self.arg_types = arg_types
@@ -213,8 +205,8 @@ def serialize(self, builder: ir.builder):
213205
# fill up IR values in template
214206
# > build function
215207
is_val = lambda path, _: path not in self.constexprs and _ is not None
216-
val_paths = list(find_paths_if(self.arg_types, is_val).keys())
217-
arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths]
208+
val_paths = list(find_paths_if(self.arg_types, is_val))
209+
arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
218210
ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
219211
return builder.get_function_ty(arg_types, ret_types)
220212

@@ -227,24 +219,24 @@ def make_template(val):
227219

228220
vals = make_template(self.arg_types)
229221
is_val = lambda path, _: path not in self.constexprs and _ is not None
230-
val_paths = list(find_paths_if(self.arg_types, is_val).keys())
222+
val_paths = list(find_paths_if(self.arg_types, is_val))
231223
# > set attributes
232224
for attr_path, attr_specs in self.attrs.items():
233225
for attr_name, attr_val in attr_specs:
234226
if attr_path in val_paths:
235227
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
236228
for i, path in enumerate(val_paths):
237-
ty = self.get_path(self.arg_types, path)
229+
ty = get_iterable_path(self.arg_types, path)
238230
if isinstance(ty, nv_tma_desc_type):
239231
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
240232
# > add IR values to the template
241233
for i, path in enumerate(val_paths):
242-
ty = self.get_path(self.arg_types, path)
243-
self.set_path(vals, path, language.tensor(fn.args(i), ty))
234+
ty = get_iterable_path(self.arg_types, path)
235+
set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
244236
# > add constexpr values to the template
245237
constants = self.constants | self.constexprs
246238
for path, val in constants.items():
247-
self.set_path(vals, path, language.constexpr(val))
239+
set_iterable_path(vals, path, language.constexpr(val))
248240
return vals
249241

250242

@@ -1139,7 +1131,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
11391131
if isinstance(arg, (language.dtype, float, int, bool)):
11401132
args[i] = language.core.constexpr(arg)
11411133
args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
1142-
args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values()
1134+
args_cst = {path: get_iterable_path(args, path) for path in args_cst}
1135+
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
1136+
args_val = [get_iterable_path(args, path) for path in args_path]
11431137
# mangle
11441138
fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
11451139
# generate function def if necessary

0 commit comments

Comments
 (0)