Skip to content

Commit 2ad8203

Browse files
Merge commit 'acc25d91fba850c18c099e7e577962ba56bdd06c'
2 parents 03cb38c + acc25d9 commit 2ad8203

File tree

2 files changed

+117
-10
lines changed

2 files changed

+117
-10
lines changed

python/triton/_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,62 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]:
2020
idx += size
2121
assert idx == len(flat)
2222
return ret
23+
24+
25+
def get_iterable_path(iterable, path):
    """Return the element of a nested container addressed by ``path``.

    ``path`` is a sequence of indices (or keys) that are applied one after
    another with ``[]``, starting from ``iterable``. An empty path returns
    ``iterable`` itself.
    """
    node = iterable
    for step in path:
        # Descend one nesting level per path entry.
        node = node[step]
    return node
28+
29+
30+
def set_iterable_path(iterable, path, val):
    """Assign ``val`` to the slot of a nested container addressed by ``path``.

    The container that owns the final slot is found by following every path
    entry except the last one; the last entry is then used for an in-place
    item assignment on that container.
    """
    if len(path) == 1:
        # Single-entry path: the slot lives directly in the root container.
        parent = iterable
    else:
        parent = get_iterable_path(iterable, path[:-1])
    parent[path[-1]] = val
33+
34+
35+
def find_paths_if(iterable, pred):
    """Collect the paths of all leaves in a nested structure that satisfy ``pred``.

    Lists, tuples, ``core.tuple`` and ``core.tuple_type`` objects are treated
    as containers and traversed recursively; anything else is a leaf.

    Args:
        iterable: the (possibly nested) structure to walk. It may itself be a
            leaf, in which case the only candidate path is the empty tuple.
        pred: callable ``pred(path, leaf) -> bool``. ``path`` is always a
            tuple of indices (``()`` for a leaf root), so predicates can
            safely use it for dict/set membership tests.

    Returns:
        A list of index tuples, in traversal order, one per leaf for which
        ``pred`` returned a truthy value.
    """
    from .language import core

    def is_iterable(x):
        return isinstance(x, (list, tuple, core.tuple, core.tuple_type))

    # A dict is used as an insertion-ordered set of paths.
    found = {}

    def _impl(current, path):
        if is_iterable(current):
            for idx, item in enumerate(current):
                _impl(item, path + (idx, ))
        elif pred(path, current):
            found[path] = None

    # Start from an empty *tuple* so that `pred` receives a hashable path of
    # the same type at every depth, including when the root itself is a leaf.
    _impl(iterable, ())
    return list(found)
58+
59+
60+
def parse_list_string(s):
    """Split a bracketed, comma-separated list string into its top-level items.

    Surrounding whitespace is ignored. One pair of enclosing square brackets
    is removed when present, and the remainder is split on commas that are
    not nested inside ``[...]``. Each item is returned as a stripped string;
    nested lists are kept verbatim (e.g. ``"[a, [b, c]]"`` gives
    ``["a", "[b, c]"]``).

    Args:
        s: the string to parse.

    Returns:
        A list of item strings. A trailing empty item (e.g. after a final
        comma, or for an empty input) is dropped.
    """
    s = s.strip()
    if s.startswith('[') and s.endswith(']'):
        # Only strip the outer brackets when the leading '[' is not closed
        # before the end of the string. Otherwise an input like "[1],[2]"
        # would lose brackets that belong to two different items.
        depth = 0
        close_pos = None
        for pos, c in enumerate(s):
            if c == '[':
                depth += 1
            elif c == ']':
                depth -= 1
                if depth == 0:
                    close_pos = pos
                    break
        if close_pos is None or close_pos == len(s) - 1:
            s = s[1:-1]

    result = []
    depth = 0
    start = 0  # index where the current item begins
    for pos, c in enumerate(s):
        if c == '[':
            depth += 1
        elif c == ']':
            depth -= 1
        elif c == ',' and depth == 0:
            result.append(s[start:pos].strip())
            start = pos + 1
    tail = s[start:].strip()
    if tail:
        result.append(tail)
    return result

python/triton/compiler/code_generator.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
1616
from types import ModuleType
1717
from triton._utils import list_list_flatten, list_list_unflatten
18+
from .._utils import find_paths_if, get_iterable_path, set_iterable_path
1819

1920

2021
def mangle_ty(ty):
@@ -189,6 +190,54 @@ def visit_Call(self, node: ast.Call) -> bool:
189190
return self.visit(node.func)
190191

191192

193+
class ASTFunction:
    """Function prototype whose arguments may be arbitrarily nested.

    Attributes are stored exactly as given to the constructor:
      - ret_types: iterable of return types, each providing ``to_ir(builder)``.
      - arg_types: (possibly nested) structure of argument types.
      - constexprs: mapping from argument path to constexpr value.
      - constants: mapping from argument path to constant value.
      - attrs: mapping from argument path to an iterable of
        ``(attr_name, attr_val)`` pairs.
    """

    def __init__(self, ret_types, arg_types, constexprs, constants, attrs):
        self.ret_types = ret_types
        self.arg_types = arg_types
        self.constexprs = constexprs
        self.constants = constants
        self.attrs = attrs

    def _value_paths(self):
        # Paths of the arguments passed as IR values: those that are not
        # listed in `constexprs` and whose entry in `arg_types` is not None.
        def is_value(path, ty):
            return path not in self.constexprs and ty is not None

        return list(find_paths_if(self.arg_types, is_value))

    def serialize(self, builder: ir.builder):
        """Build the IR function type from the value arguments and return types."""
        ir_arg_types = []
        for path in self._value_paths():
            ir_arg_types.append(get_iterable_path(self.arg_types, path).to_ir(builder))
        ir_ret_types = [ty.to_ir(builder) for ty in self.ret_types]
        return builder.get_function_ty(ir_arg_types, ir_ret_types)

    def deserialize(self, fn):
        """Return the argument structure of ``fn`` filled with tensors and constexprs.

        Also sets the argument attributes recorded in ``self.attrs`` on ``fn``
        and marks ``nv_tma_desc_type`` arguments with ``tt.nv_tma_desc``.
        """

        # Template: same nesting as `arg_types`, every leaf is constexpr(None).
        def make_template(val):
            if not isinstance(val, (list, tuple, language.tuple_type)):
                return language.constexpr(None)
            return language.tuple([make_template(x) for x in val])

        template = make_template(self.arg_types)
        value_paths = self._value_paths()
        # Position of each value path in the flat IR argument list.
        arg_index = {path: i for i, path in enumerate(value_paths)}

        # Attributes are only applied to paths that are IR value arguments.
        for attr_path, attr_specs in self.attrs.items():
            position = arg_index.get(attr_path)
            for attr_name, attr_val in attr_specs:
                if position is not None:
                    fn.set_arg_attr(position, attr_name, attr_val)
        for i, path in enumerate(value_paths):
            if isinstance(get_iterable_path(self.arg_types, path), nv_tma_desc_type):
                fn.set_arg_attr(i, "tt.nv_tma_desc", 1)

        # Place the IR values into the template.
        for i, path in enumerate(value_paths):
            ty = get_iterable_path(self.arg_types, path)
            set_iterable_path(template, path, language.tensor(fn.args(i), ty))

        # Place constant values into the template; on a shared path the
        # entry from `constexprs` wins over the one from `constants`.
        merged = self.constants | self.constexprs
        for path, val in merged.items():
            set_iterable_path(template, path, language.constexpr(val))
        return template
239+
240+
192241
class CodeGenerator(ast.NodeVisitor):
193242

194243
def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
@@ -1083,16 +1132,15 @@ def visit_Assert(self, node) -> Any:
10831132
def call_JitFunction(self, fn: JITFunction, args, kwargs):
10841133
args = inspect.getcallargs(fn.fn, *args, **kwargs)
10851134
args = [args[name] for name in fn.arg_names]
1086-
args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args]
1087-
# generate function def
1088-
attributes = {}
1089-
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
1090-
constants = {i: args[i] for i in constexprs}
1091-
# generate call
1092-
args = [None if i in constexprs else arg for i, arg in enumerate(args)]
1093-
arg_vals = [arg.handle for arg in args if arg is not None]
1094-
arg_types = [arg.type for arg in args if arg is not None]
1095-
fn_name = mangle_fn(fn.__name__, arg_types, constants)
1135+
for i, arg in enumerate(args):
1136+
if isinstance(arg, (language.dtype, float, int, bool)):
1137+
args[i] = language.core.constexpr(arg)
1138+
args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x))
1139+
args_cst = {path: get_iterable_path(args, path) for path in args_cst}
1140+
args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x))
1141+
args_val = [get_iterable_path(args, path) for path in args_path]
1142+
# mangle
1143+
fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst)
10961144
# generate function def if necessary
10971145
if not self.module.has_function(fn_name):
10981146
prototype = language.function_type([], arg_types)

0 commit comments

Comments
 (0)