diff --git a/.github/actions/setup-pytorch/action.yml b/.github/actions/setup-pytorch/action.yml index 16fba7afa3..e57eef8757 100644 --- a/.github/actions/setup-pytorch/action.yml +++ b/.github/actions/setup-pytorch/action.yml @@ -82,7 +82,7 @@ runs: uses: ./.github/actions/load env: # Increase this value to reset cache - CACHE_NUMBER: 11 + CACHE_NUMBER: 12 with: path: pytorch key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER diff --git a/.github/pins/pytorch-upstream.txt b/.github/pins/pytorch-upstream.txt index 4cf70cf66d..a3619a3159 100644 --- a/.github/pins/pytorch-upstream.txt +++ b/.github/pins/pytorch-upstream.txt @@ -1 +1 @@ -8321eec009c8c79145ebccd51fdfc336e5f8b848 +487873f7cafeb0fd390eaefe40496b804bceabbd diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index 11ce4f8fc1..206d132301 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -59,6 +59,8 @@ def walk_fn(op): torch.empty((32, 32), device=device), # out_ptr 16, # BLOCK_SIZE ] + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) src = triton.compiler.compiler.ASTSource( fn=kernel, signature={ @@ -69,12 +71,10 @@ def walk_fn(op): constants={kernel.arg_names[i]: arg for i, arg in enumerate(args) if not isinstance(arg, torch.Tensor)}, - attrs=kernel._get_config(*args, ), + attrs=backend.get_attrs_descriptor(args, kernel.params), ) context = triton._C.libtriton.ir.context() - target = triton.runtime.driver.active.get_current_target() - backend = triton.compiler.compiler.make_backend(target) options = backend.parse_options(dict()) codegen_fns = dict() module_map = backend.get_module_map() diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index c4afd1e0ed..0277792330 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -3,6 +3,7 @@ import triton 
import triton.language as tl +from triton.backends.compiler import AttrsDescriptor from triton.compiler import ASTSource target = triton.runtime.driver.active.get_current_target() @@ -25,7 +26,7 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: - config = triton.compiler.AttrsDescriptor(tuple(range(4)), ()) + config = AttrsDescriptor.from_hints({i: 16 for i in range(4)}) multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, )) proc.start() @@ -47,7 +48,7 @@ def kernel_dot(Z): def test_compile_in_forked_subproc(fresh_triton_cache) -> None: - config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + config = AttrsDescriptor.from_hints({0: 16}) assert multiprocessing.get_start_method() == 'fork' proc = multiprocessing.Process(target=compile_fn_dot, args=(config, )) proc.start() @@ -86,7 +87,7 @@ def test_compile_in_forked_subproc_with_forced_gc(fresh_triton_cache) -> None: gc.disable() # stage 1.p - config = triton.compiler.AttrsDescriptor(tuple(range(1)), ()) + config = AttrsDescriptor.from_hints({0: 16}) compile_empty_kernel_with_gc(config) # stage 2.p diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 3df0581582..11e1dc4cef 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -1,5 +1,6 @@ import os import re +import hashlib import subprocess from abc import ABCMeta, abstractmethod, abstractclassmethod @@ -8,6 +9,187 @@ from types import ModuleType +class AttrsDescriptor: + """ + This class handles compile-time properties for specific function parameters. + + Different backends can add more properties to the common ones. The class + contains two fields: + + `arg_properties`: a dictionary containing the different compile-time properties for different + parameters. 
I.e., the dictionary is a map from property names to parameter indices + { + "prop0": (0, 2, 3) + "prop1": (0, 4, 5) + } + Different backends might need different properties on those parameters to enable + specific optimizations. The common compile time properties contained in this class + are : + - "tt.divisibility", i.e., is the given parameter divisible by 16 + - "tt.equal_to", i.e., is the given parameter an integer constant 1 + + `property_values`: a dictionary containing the value of the different compile-time properties, like: + { + "prop0": val0 + "prop1": val1 + } + + `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant + + """ + __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + + def __init__(self, params=None, values=None): + """ + Initialize the compile-time properties + + We can initialize the AttrsDescriptor class by passing the list of params + of the function and their `values`. The function will try to apply the properties + to the values and save the parameters in the `arg_properties` dictionary. 
If we don't pass + either the `params` or the `values` we should initialize the class via an alternative method + (see `from_dict` or `from_hints`) + """ + # Default initialization + self.arg_properties = {} + self.property_values = {} + self.constant_properties = set() + + self._add_common_properties(params, values) + self._add_backend_properties(params, values) + self._init_slots() + + def _add_common_properties(self, params, values): + """ Add common compile-time properties """ + self.property_values["tt.divisibility"] = 16 + self.property_values["tt.equal_to"] = 1 + self.constant_properties.add("tt.equal_to") + + if (params is None) or (values is None): + return + + # Compile properties deduction + assert (len(params) == len(values)) + + # Divisibility property + self.arg_properties["tt.divisibility"] = [ + param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] + + # Equal to 1 property + self.arg_properties["tt.equal_to"] = [ + param.num + for param, arg in zip(params, values) + if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize + ] + + def _add_backend_properties(self, params=None, values=None): + """ This method is for different subclasses to implement their own compile-time properties """ + pass + + def _init_slots(self): + """ Initialize the slots of this class """ + for name, val in self.arg_properties.items(): + setattr(self, name.removeprefix('tt.') + '_' + str(self.property_values[name]), val) + + def get_fn_attrs(self) -> Dict: + """ + Get the function attributes as a dictionary. 
+ + The returned dictionary will look like : + { + "arg0" : [(prop_name00, val00), (prop_name01, val01), ...)]} + "arg1" : [(prop_name10, val10), (prop_name11, val11), ...)]} + } + """ + attrs = {} + for prop_name, arg_set in self.arg_properties.items(): + prop_val = self.property_values[prop_name] + for arg in arg_set: + attrs[arg] = attrs.get(arg, []) + [(prop_name, prop_val)] + return attrs + + def get_constants(self) -> Dict: + """ Return a mapping of constant parameters to their values """ + constants = {} + for prop_name in self.constant_properties: + for p in self.arg_properties.get(prop_name, []): + constants[p] = self.property_values[prop_name] + return constants + + def filter_out_constants(self): + """ Return the same object, without properties marked as constants""" + import copy + c = copy.deepcopy(self) + for prop_name in c.constant_properties: + c.arg_properties.pop(prop_name, None) + c.property_values.pop(prop_name, None) + c.constant_properties = {} + return c + + def hash(self): + values = [sorted(self.arg_properties.values())] + values += [sorted(self.property_values.values())] + values += [sorted(self.constant_properties)] + key = str(values) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def to_dict(self): + return self.arg_properties + + @staticmethod + def from_dict(data): + attrsDescriptor = AttrsDescriptor() + for prop_name, param_ids in data.items(): + attrsDescriptor.arg_properties[prop_name] = param_ids + attrsDescriptor._init_slots() + return attrsDescriptor + + @staticmethod + def from_hints(hints: list[tuple[int, int]]): + """ + Create the class from a set of hints that are passed in. 
+ + Instead of deducing the properties from a list of parameters and values, + the user can pass in a dictionary of `hints={param_index: val}` and if `val` + matches one of the values of the properties (e.g., `prop_val[prop0]`), + then we insert `param_index` into the correct list (e.g., in + `arg_properties[prop0]`) + """ + attrsDescriptor = AttrsDescriptor() + for prop_name, prop_val in attrsDescriptor.property_values.items(): + attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] + attrsDescriptor._init_slots() + return attrsDescriptor + + @staticmethod + def is_divisible_by_16(x): + """ Return if the argument is a multiple of 16""" + if hasattr(x, "data_ptr"): + return x.data_ptr() % 16 == 0 + elif isinstance(x, int): + return x % 16 == 0 + if x is None: + return True + return False + + @staticmethod + def is_equal_to_1(x): + """ Return if the argument is a constant 1""" + return True if isinstance(x, int) and not isinstance(x, bool) and x == 1 else False + + @staticmethod + def get_property_key(val, align): + if align and AttrsDescriptor.is_divisible_by_16(val): + return "D" + if AttrsDescriptor.is_equal_to_1(val): + return "1" + return "N" + + def __repr__(self): + return f"AttrsDescriptor.from_dict({self.arg_properties})" + + @dataclass(frozen=True) class GPUTarget(object): # Target backend, e.g., cuda, hip @@ -79,6 +261,20 @@ def load_dialects(self, context): @abstractmethod def get_module_map(self) -> Dict[str, ModuleType]: """ - Return a map of interface modules to their device-specific implementations. + Return a map of interface modules to their device-specific implementations """ raise NotImplementedError + + def get_attrs_descriptor(self, params, args): + """ + Return an attribute descriptor: given a set of parameters and arguments + the descriptor stores a set of compile time properties that can improve code + generation. 
Different backends might benefit from different properties + """ + return AttrsDescriptor(params, args) + + def compute_spec_key(self, arg, align): + """ + Return the ascii key for a given argument with a given set of properties + """ + return AttrsDescriptor.get_property_key(arg, align) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index ce0cfedfcd..bbe8c047c6 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,4 +1,4 @@ -from .compiler import CompiledKernel, ASTSource, compile, AttrsDescriptor, make_backend, LazyDict +from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict from .errors import CompilationError __all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index acc58b5952..9ab3f4bc0c 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1265,7 +1265,7 @@ def kernel_suffix(signature, specialization): suffix += str(i) if i in specialization.equal_to_1: suffix += 'c' - if i in specialization.divisible_by_16: + if i in specialization.divisibility_16: suffix += 'd' return suffix @@ -1279,9 +1279,13 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): gscope = fn.__globals__.copy() function_name = fn.repr(specialization) tys = list(specialization.signature.values()) - new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in attrs.equal_to_1} - new_attrs = {k: [("tt.divisibility", 16)] for k in attrs.divisible_by_16} + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True + new_attrs = attrs.filter_out_constants() + fn_attrs = new_attrs.get_fn_attrs() all_constants = constants.copy() 
all_constants.update(new_constants) arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] @@ -1289,7 +1293,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): prototype = language.function_type([], arg_types) generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name, + jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 4844867ccf..77be3a2331 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -3,45 +3,19 @@ import json from .._C.libtriton import get_cache_invalidating_env_vars, ir from ..backends import backends -from ..backends.compiler import GPUTarget +from ..backends.compiler import GPUTarget, AttrsDescriptor from .. 
import __version__ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver from ..tools.disasm import get_sass # TODO: this shouldn't be here -from dataclasses import dataclass from .code_generator import ast_to_ttir from pathlib import Path import re import functools import os - -@dataclass -class AttrsDescriptor: - divisible_by_16: set = None - equal_to_1: set = None - - def __post_init__(self): - if self.divisible_by_16 is None: - self.divisible_by_16 = set() - if self.equal_to_1 is None: - self.equal_to_1 = set() - - def to_dict(self): - return {'divisible_by_16': list(self.divisible_by_16), 'equal_to_1': list(self.equal_to_1)} - - @staticmethod - def from_dict(data): - return AttrsDescriptor(divisible_by_16=set(data.get('divisible_by_16', [])), - equal_to_1=set(data.get('equal_to_1', []))) - - def hash(self): - key = str([sorted(x) for x in self.__dict__.values()]) - return hashlib.sha256(key.encode("utf-8")).hexdigest() - - # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, # and any following whitespace # - (public\s+)? : optionally match the keyword public and any following whitespace diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index e417e11104..0842849ad9 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -342,7 +342,7 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke return serialized_obj -def create_function_from_signature(sig, kparams): +def create_function_from_signature(sig, kparams, backend): """ Equivalent to sig.bind followed by apply_defaults. 
This generates a native Python function (using exec) which can be memoized on a per-kernel @@ -401,7 +401,7 @@ def create_function_from_signature(sig, kparams): } func_namespace['mangle_type'] = mangle_type - func_namespace['compute_spec_key'] = compute_spec_key + func_namespace['compute_spec_key'] = backend.compute_spec_key # Execute the function string in func_namespace to create the function exec(func_body, func_namespace) @@ -445,7 +445,6 @@ class JITFunction(KernelInterface[T]): # Hook to signal that a kernel is done compiling and inspect compiled function. # cache_hook will always be called before compilation and compiled_hook after. compiled_hook = None - divisibility = 16 @staticmethod def _key_of(arg): @@ -467,42 +466,6 @@ def _key_of(arg): else: raise TypeError(f"Unsupported type {type(arg)} for {arg}") - @staticmethod - def _spec_of(arg): - if hasattr(arg, "data_ptr"): - return arg.data_ptr() % JITFunction.divisibility == 0 - elif isinstance(arg, int): - return (arg % 16 == 0, arg == 1) - return (arg is None, ) - - def _get_config(self, *args): - from ..compiler import AttrsDescriptor - - def is_divisible_by_16(x): - if hasattr(x, "data_ptr"): - return x.data_ptr() % JITFunction.divisibility == 0 - elif isinstance(x, int): - return x % JITFunction.divisibility == 0 - if x is None: - return True - return False - - divisible_by_16 = { - param.num - for param, arg in zip(self.params, args) - if is_divisible_by_16(arg) and not param.do_not_specialize and not param.do_not_specialize_on_alignment - } - equal_to_1 = { - param.num - for param, arg in zip(self.params, args) - if isinstance(arg, int) and not isinstance(arg, bool) and arg == 1 and not param.do_not_specialize - } - # folded equal_to_1 and None - # TODO: method to collect all folded args - return AttrsDescriptor(tuple(divisible_by_16), tuple(equal_to_1)) - # return _triton.code_gen.instance_descriptor(divisible_by_16, - # equal_to_1) - @staticmethod def _type_of(key, is_const=False): # `None` is 
nullptr. Implicitly convert to *i8. @@ -581,7 +544,7 @@ def add_pre_run_hook(self, hook): assert callable(hook) self.pre_run_hooks.append(hook) - def create_binder(self): + def create_binder(self, backend): """ Precompute as much as possible. """ @@ -590,7 +553,7 @@ def create_binder(self): self.compile = compile self.ASTSource = ASTSource self.make_backend = make_backend - self.binder = create_function_from_signature(self.signature, self.params) + self.binder = create_function_from_signature(self.signature, self.params, backend) self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] self.specialised_indices = [ @@ -601,15 +564,18 @@ def run(self, *args, grid, warmup, **kwargs): kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" # parse options + from ..compiler import make_backend device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) + target = driver.active.get_current_target() + backend = make_backend(target) # Execute pre run hooks with args and kwargs for hook in self.pre_run_hooks: hook(*args, **kwargs) if self.binder is None: - self.create_binder() + self.create_binder(backend) bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) @@ -619,8 +585,6 @@ def run(self, *args, grid, warmup, **kwargs): if kernel is None: # Kernel is not cached; we have to compile. 
- target = driver.active.get_current_target() - backend = self.make_backend(target) options = backend.parse_options(kwargs) # deprecated arguments @@ -641,11 +605,12 @@ def run(self, *args, grid, warmup, **kwargs): sigvals = sig_and_spec[:len(sigkeys)] signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - configs = (self._get_config(*bound_vals), ) + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() constants = { p.name: v for (v, p) in zip(bound_vals, self.params) - if p.is_constexpr or p.num in configs[0].equal_to_1 or v is None + if p.is_constexpr or (p.num in constant_params) or v is None } for i, arg in constants.items(): if callable(arg): @@ -763,7 +728,8 @@ def warmup(self, *args, grid, **kwargs): return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) def preload(self, specialization_data): - from ..compiler import AttrsDescriptor, compile, ASTSource + from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor import json import triton.language as tl device = driver.active.get_current_device() diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 1e9697cc82..443341fa0d 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -7,6 +7,7 @@ from typing import List import triton +import triton.backends from triton.compiler.code_generator import kernel_suffix from triton.backends.nvidia.driver import ty_to_cpp @@ -106,11 +107,9 @@ def constexpr(s): # compile ast into cubin for h in hints.values(): assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - divisible_by_16 = [i for i, h in hints.items() if h == 16] - equal_to_1 = [i for i, h in hints.items() if h == 1] - attrs = triton.compiler.AttrsDescriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) - for i in equal_to_1: - constants.update({kernel.arg_names[i]: 1}) + attrs = 
triton.backends.compiler.AttrsDescriptor.from_hints(hints) + for p, v in attrs.get_constants().items(): + constants.update({kernel.arg_names[p]: v}) src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs) opts = {"num_warps": args.num_warps, "num_stages": args.num_stages} ccinfo = triton.compile(src, options=opts) @@ -124,7 +123,7 @@ def constexpr(s): arg_types.append(signature[arg_name]) arg_names_not_1.append(arg_name) arg_types_not_1.append(signature[arg_name]) - elif i in equal_to_1: + elif i in attrs.equal_to_1: arg_names.append(arg_name) arg_types.append(signature[arg_name]) diff --git a/scripts/patch-pytorch.sh b/scripts/patch-pytorch.sh index c9dbf931ca..7ceb300b74 100755 --- a/scripts/patch-pytorch.sh +++ b/scripts/patch-pytorch.sh @@ -17,3 +17,4 @@ cd "$REPO_ROOT" curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply - curl -sSL https://github.com/pytorch/pytorch/pull/126456.diff | git apply - +curl -sSL https://github.com/pytorch/pytorch/pull/138390.diff | git apply -