Skip to content

Commit ea006f2

Browse files
Revert "[FRONTEND] do not hold to references unnecessarily (#5402)"
This reverts commit acc25d9.
1 parent 2ad8203 commit ea006f2

File tree

2 files changed

+10
-117
lines changed

2 files changed

+10
-117
lines changed

python/triton/_utils.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -20,62 +20,3 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]:
2020
idx += size
2121
assert idx == len(flat)
2222
return ret
23-
24-
25-
def get_iterable_path(iterable, path):
26-
from functools import reduce
27-
return reduce(lambda a, idx: a[idx], path, iterable)
28-
29-
30-
def set_iterable_path(iterable, path, val):
31-
prev = iterable if len(path) == 1 else get_iterable_path(iterable, path[:-1])
32-
prev[path[-1]] = val
33-
34-
35-
def find_paths_if(iterable, pred):
36-
from .language import core
37-
is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type))
38-
ret = dict()
39-
40-
def _impl(current, path):
41-
path = (path[0], ) if len(path) == 1 else tuple(path)
42-
if is_iterable(current):
43-
for idx, item in enumerate(current):
44-
_impl(item, path + (idx, ))
45-
elif pred(path, current):
46-
if len(path) == 1:
47-
ret[(path[0], )] = None
48-
else:
49-
ret[tuple(path)] = None
50-
51-
if is_iterable(iterable):
52-
_impl(iterable, [])
53-
elif pred(list(), iterable):
54-
ret = {tuple(): None}
55-
else:
56-
ret = dict()
57-
return list(ret.keys())
58-
59-
60-
def parse_list_string(s):
61-
s = s.strip()
62-
if s.startswith('[') and s.endswith(']'):
63-
s = s[1:-1]
64-
result = []
65-
current = ''
66-
depth = 0
67-
for c in s:
68-
if c == '[':
69-
depth += 1
70-
current += c
71-
elif c == ']':
72-
depth -= 1
73-
current += c
74-
elif c == ',' and depth == 0:
75-
result.append(current.strip())
76-
current = ''
77-
else:
78-
current += c
79-
if current.strip():
80-
result.append(current.strip())
81-
return result

python/triton/compiler/code_generator.py

Lines changed: 10 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
1616
from types import ModuleType
1717
from triton._utils import list_list_flatten, list_list_unflatten
18-
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
1918

2019

2120
def mangle_ty(ty):
@@ -190,54 +189,6 @@ def visit_Call(self, node: ast.Call) -> bool:
190189
return self.visit(node.func)
191190

192191

193-
class ASTFunction:
194-
195-
def __init__(self, ret_types, arg_types, constexprs, constants, attrs):
196-
self.ret_types = ret_types
197-
self.arg_types = arg_types
198-
self.constexprs = constexprs
199-
self.constants = constants
200-
self.attrs = attrs
201-
202-
def serialize(self, builder: ir.builder):
203-
# fill up IR values in template
204-
# > build function
205-
is_val = lambda path, _: path not in self.constexprs and _ is not None
206-
val_paths = list(find_paths_if(self.arg_types, is_val))
207-
arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths]
208-
ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
209-
return builder.get_function_ty(arg_types, ret_types)
210-
211-
def deserialize(self, fn):
212-
# create "template"
213-
def make_template(val):
214-
if isinstance(val, (list, tuple, language.tuple_type)):
215-
return language.tuple([make_template(x) for x in val])
216-
return language.constexpr(None)
217-
218-
vals = make_template(self.arg_types)
219-
is_val = lambda path, _: path not in self.constexprs and _ is not None
220-
val_paths = list(find_paths_if(self.arg_types, is_val))
221-
# > set attributes
222-
for attr_path, attr_specs in self.attrs.items():
223-
for attr_name, attr_val in attr_specs:
224-
if attr_path in val_paths:
225-
fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val)
226-
for i, path in enumerate(val_paths):
227-
ty = get_iterable_path(self.arg_types, path)
228-
if isinstance(ty, nv_tma_desc_type):
229-
fn.set_arg_attr(i, "tt.nv_tma_desc", 1)
230-
# > add IR values to the template
231-
for i, path in enumerate(val_paths):
232-
ty = get_iterable_path(self.arg_types, path)
233-
set_iterable_path(vals, path, language.tensor(fn.args(i), ty))
234-
# > add constexpr values to the template
235-
constants = self.constants | self.constexprs
236-
for path, val in constants.items():
237-
set_iterable_path(vals, path, language.constexpr(val))
238-
return vals
239-
240-
241192
class CodeGenerator(ast.NodeVisitor):
242193

243194
def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
@@ -1132,15 +1083,16 @@ def visit_Assert(self, node) -> Any:
11321083
def call_JitFunction(self, fn: JITFunction, args, kwargs):
11331084
args = inspect.getcallargs(fn.fn, *args, **kwargs)
11341085
args = [args[name] for name in fn.arg_names]
1135-
for i, arg in enumerate(args):
1136-
if isinstance(arg, (language.dtype, float, int, bool)):
1137-
args[i] = language.core.constexpr(arg)
1138-
args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
1139-
args_cst = {path: get_iterable_path(args, path) for path in args_cst}
1140-
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
1141-
args_val = [get_iterable_path(args, path) for path in args_path]
1142-
# mangle
1143-
fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
1086+
args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args]
1087+
# generate function def
1088+
attributes = {}
1089+
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
1090+
constants = {i: args[i] for i in constexprs}
1091+
# generate call
1092+
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
1093+
arg_vals = [arg.handle for arg in args if arg is not None]
1094+
arg_types = [arg.type for arg in args if arg is not None]
1095+
fn_name = mangle_fn(fn.__name__, arg_types, constants)
11441096
# generate function def if necessary
11451097
if not self.module.has_function(fn_name):
11461098
prototype = language.function_type([], arg_types)

0 commit comments

Comments
 (0)