|
15 | 15 | from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) |
16 | 16 | from types import ModuleType |
17 | 17 | from triton._utils import list_list_flatten, list_list_unflatten |
18 | | -from .._utils import find_paths_if, get_iterable_path, set_iterable_path |
19 | 18 |
|
20 | 19 |
|
21 | 20 | def mangle_ty(ty): |
@@ -190,54 +189,6 @@ def visit_Call(self, node: ast.Call) -> bool: |
190 | 189 | return self.visit(node.func) |
191 | 190 |
|
192 | 191 |
|
193 | | -class ASTFunction: |
194 | | - |
195 | | - def __init__(self, ret_types, arg_types, constexprs, constants, attrs): |
196 | | - self.ret_types = ret_types |
197 | | - self.arg_types = arg_types |
198 | | - self.constexprs = constexprs |
199 | | - self.constants = constants |
200 | | - self.attrs = attrs |
201 | | - |
202 | | - def serialize(self, builder: ir.builder): |
203 | | - # fill up IR values in template |
204 | | - # > build function |
205 | | - is_val = lambda path, _: path not in self.constexprs and _ is not None |
206 | | - val_paths = list(find_paths_if(self.arg_types, is_val)) |
207 | | - arg_types = [get_iterable_path(self.arg_types, path).to_ir(builder) for path in val_paths] |
208 | | - ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] |
209 | | - return builder.get_function_ty(arg_types, ret_types) |
210 | | - |
211 | | - def deserialize(self, fn): |
212 | | - # create "template" |
213 | | - def make_template(val): |
214 | | - if isinstance(val, (list, tuple, language.tuple_type)): |
215 | | - return language.tuple([make_template(x) for x in val]) |
216 | | - return language.constexpr(None) |
217 | | - |
218 | | - vals = make_template(self.arg_types) |
219 | | - is_val = lambda path, _: path not in self.constexprs and _ is not None |
220 | | - val_paths = list(find_paths_if(self.arg_types, is_val)) |
221 | | - # > set attributes |
222 | | - for attr_path, attr_specs in self.attrs.items(): |
223 | | - for attr_name, attr_val in attr_specs: |
224 | | - if attr_path in val_paths: |
225 | | - fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) |
226 | | - for i, path in enumerate(val_paths): |
227 | | - ty = get_iterable_path(self.arg_types, path) |
228 | | - if isinstance(ty, nv_tma_desc_type): |
229 | | - fn.set_arg_attr(i, "tt.nv_tma_desc", 1) |
230 | | - # > add IR values to the template |
231 | | - for i, path in enumerate(val_paths): |
232 | | - ty = get_iterable_path(self.arg_types, path) |
233 | | - set_iterable_path(vals, path, language.tensor(fn.args(i), ty)) |
234 | | - # > add constexpr values to the template |
235 | | - constants = self.constants | self.constexprs |
236 | | - for path, val in constants.items(): |
237 | | - set_iterable_path(vals, path, language.constexpr(val)) |
238 | | - return vals |
239 | | - |
240 | | - |
241 | 192 | class CodeGenerator(ast.NodeVisitor): |
242 | 193 |
|
243 | 194 | def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, |
@@ -1132,15 +1083,16 @@ def visit_Assert(self, node) -> Any: |
1132 | 1083 | def call_JitFunction(self, fn: JITFunction, args, kwargs): |
1133 | 1084 | args = inspect.getcallargs(fn.fn, *args, **kwargs) |
1134 | 1085 | args = [args[name] for name in fn.arg_names] |
1135 | | - for i, arg in enumerate(args): |
1136 | | - if isinstance(arg, (language.dtype, float, int, bool)): |
1137 | | - args[i] = language.core.constexpr(arg) |
1138 | | - args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) |
1139 | | - args_cst = {path: get_iterable_path(args, path) for path in args_cst} |
1140 | | - args_path = find_paths_if(args, lambda _, x: not _is_constexpr(x)) |
1141 | | - args_val = [get_iterable_path(args, path) for path in args_path] |
1142 | | - # mangle |
1143 | | - fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) |
| 1086 | + args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] |
| 1087 | + # generate function def |
| 1088 | + attributes = {} |
| 1089 | + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] |
| 1090 | + constants = {i: args[i] for i in constexprs} |
| 1091 | + # generate call |
| 1092 | + args = [None if i in constexprs else arg for i, arg in enumerate(args)] |
| 1093 | + arg_vals = [arg.handle for arg in args if arg is not None] |
| 1094 | + arg_types = [arg.type for arg in args if arg is not None] |
| 1095 | + fn_name = mangle_fn(fn.__name__, arg_types, constants) |
1144 | 1096 | # generate function def if necessary |
1145 | 1097 | if not self.module.has_function(fn_name): |
1146 | 1098 | prototype = language.function_type([], arg_types) |
|
0 commit comments