|
15 | 15 | from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) |
16 | 16 | from types import ModuleType |
17 | 17 | from triton._utils import list_list_flatten, list_list_unflatten |
| 18 | +from .._utils import find_paths_if, get_iterable_path, set_iterable_path |
18 | 19 |
|
19 | 20 |
|
20 | 21 | def mangle_ty(ty): |
@@ -189,6 +190,54 @@ def visit_Call(self, node: ast.Call) -> bool: |
189 | 190 | return self.visit(node.func) |
190 | 191 |
|
191 | 192 |
|
| 193 | +class ASTFunction: |
| 194 | + |
| 195 | + def __init__(self, ret_types, arg_types, constexprs, constants, attrs): |
| 196 | + self.ret_types = ret_types |
| 197 | + self.arg_types = arg_types |
| 198 | + self.constexprs = constexprs |
| 199 | + self.constants = constants |
| 200 | + self.attrs = attrs |
| 201 | + |
| 202 | + def serialize(self, builder: ir.builder): |
| 203 | + # fill up IR values in template |
| 204 | + # > build function |
| 205 | + is_val = lambda path, _: path not in self.constexprs and _ is not None |
| 206 | + val_paths = list(find_paths_if(self.arg_types, is_val)) |
| 207 | + arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths] |
| 208 | + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] |
| 209 | + return builder.get_function_ty(arg_types, ret_types) |
| 210 | + |
| 211 | + def deserialize(self, fn): |
| 212 | + # create "template" |
| 213 | + def make_template(val): |
| 214 | + if isinstance(val, (list, tuple, language.tuple_type)): |
| 215 | + return language.tuple([make_template(x) for x in val]) |
| 216 | + return language.constexpr(None) |
| 217 | + |
| 218 | + vals = make_template(self.arg_types) |
| 219 | + is_val = lambda path, _: path not in self.constexprs and _ is not None |
| 220 | + val_paths = list(find_paths_if(self.arg_types, is_val)) |
| 221 | + # > set attributes |
| 222 | + for attr_path, attr_specs in self.attrs.items(): |
| 223 | + for attr_name, attr_val in attr_specs: |
| 224 | + if attr_path in val_paths: |
| 225 | + fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) |
| 226 | + for i, path in enumerate(val_paths): |
| 227 | + ty = get_iterable_path(self.arg_types, path) |
| 228 | + if isinstance(ty, nv_tma_desc_type): |
| 229 | + fn.set_arg_attr(i, "tt.nv_tma_desc", 1) |
| 230 | + # > add IR values to the template |
| 231 | + for i, path in enumerate(val_paths): |
| 232 | + ty = get_iterable_path(self.arg_types, path) |
| 233 | + set_iterable_path(vals, path, language.tensor(fn.args(i), ty)) |
| 234 | + # > add constexpr values to the template |
| 235 | + constants = self.constants | self.constexprs |
| 236 | + for path, val in constants.items(): |
| 237 | + set_iterable_path(vals, path, language.constexpr(val)) |
| 238 | + return vals |
| 239 | + |
| 240 | + |
192 | 241 | class CodeGenerator(ast.NodeVisitor): |
193 | 242 |
|
194 | 243 | def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, |
@@ -1083,16 +1132,15 @@ def visit_Assert(self, node) -> Any: |
1083 | 1132 | def call_JitFunction(self, fn: JITFunction, args, kwargs): |
1084 | 1133 | args = inspect.getcallargs(fn.fn, *args, **kwargs) |
1085 | 1134 | args = [args[name] for name in fn.arg_names] |
1086 | | - args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] |
1087 | | - # generate function def |
1088 | | - attributes = {} |
1089 | | - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] |
1090 | | - constants = {i: args[i] for i in constexprs} |
1091 | | - # generate call |
1092 | | - args = [None if i in constexprs else arg for i, arg in enumerate(args)] |
1093 | | - arg_vals = [arg.handle for arg in args if arg is not None] |
1094 | | - arg_types = [arg.type for arg in args if arg is not None] |
1095 | | - fn_name = mangle_fn(fn.__name__, arg_types, constants) |
| 1135 | + for i, arg in enumerate(args): |
| 1136 | + if isinstance(arg, (language.dtype, float, int, bool)): |
| 1137 | + args[i] = language.core.constexpr(arg) |
| 1138 | + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) |
| 1139 | + args_cst = {path: get_iterable_path(args, path) for path in args_cst} |
| 1140 | + args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x)) |
| 1141 | + args_val = [get_iterable_path(args, path) for path in args_path] |
| 1142 | + # mangle |
| 1143 | + fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) |
1096 | 1144 | # generate function def if necessary |
1097 | 1145 | if not self.module.has_function(fn_name): |
1098 | 1146 | prototype = language.function_type([], arg_types) |
|
0 commit comments