diff --git a/.github/pins/pytorch-upstream.txt b/.github/pins/pytorch-upstream.txt index 34c584cb51..c2ce8b1a5f 100644 --- a/.github/pins/pytorch-upstream.txt +++ b/.github/pins/pytorch-upstream.txt @@ -1 +1 @@ -0efa590d435d2b4aefcbad9014dd5fa75dcf8405 +33dce10ece5b38aa0ab76739b658cd980a6e3d8f diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a0084e0be9..a45cb3f888 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -9,6 +9,7 @@ import triton import triton.language as tl from triton.runtime.jit import JITFunction +from triton._internal_testing import is_hip @triton.jit @@ -572,3 +573,29 @@ def compiled_hook(*args, **kwargs): assert specialization_data is not None and specialization_data_compiled == specialization_data assert is_warmup is True assert key in kernel_add.cache[getattr(torch, device).current_device()] + + +@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip()) +def test_within_2gb(device, fresh_triton_cache) -> None: + + @triton.jit + def kernel_add(a): + tl.load(a) + + # This is the attribute we want to test + pointer_range_32 = None + + def cache_hook(*args, **kwargs): + nonlocal pointer_range_32 + pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32 + + JITFunction.cache_hook = cache_hook + # In warmup we assume that the pointer range is 32 bits + kernel_add.warmup(torch.float32, grid=(1, )) + assert pointer_range_32 == [0] + # Torch tensor > 2GB + kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) + assert len(pointer_range_32) == 0 + # Torch tensor <= 2GB + kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) + assert pointer_range_32 == [0] diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 11e1dc4cef..92486cdc66 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -8,7 +8,21 
@@ from typing import Dict, Union from types import ModuleType +# Table that associates strings to AttrsDescriptor (sub)classes. +# In this way we can dynamically select the correct class +# constructor +_descriptor_table = {} + +def register_descriptor(cls): + """ + Register a descriptor into the descriptor table + """ + _descriptor_table[cls.__name__] = cls + return cls + + +@register_descriptor class AttrsDescriptor: """ This class handles compile-time properties for specific function parameters. @@ -135,18 +149,28 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() def to_dict(self): - return self.arg_properties + """ + Store the fields of this class in a serializable dictionary + """ + # We need to only store the `arg_properties` field. To initialize the + # other fields we rely on the class type. We store it as a string in + # the dictionary so that we can use it to invoke the appropriate + # (sub)class constructor in the `from_dict` method. + return {"arg_properties": self.arg_properties, "cls": type(self).__name__} @staticmethod def from_dict(data): - attrsDescriptor = AttrsDescriptor() - for prop_name, param_ids in data.items(): - attrsDescriptor.arg_properties[prop_name] = param_ids - attrsDescriptor._init_slots() - return attrsDescriptor - - @staticmethod - def from_hints(hints: list[tuple[int, int]]): + """ + Create the object from a serializable dictionary + """ + attrs_descriptor = _descriptor_table[data["cls"]]() + for prop_name, param_ids in data["arg_properties"].items(): + attrs_descriptor.arg_properties[prop_name] = param_ids + attrs_descriptor._init_slots() + return attrs_descriptor + + @classmethod + def from_hints(cls, hints: list[tuple[int, int]]): """ Create the class from a set of hints that are passed in. 
@@ -156,11 +180,11 @@ def from_hints(hints: list[tuple[int, int]]): then we insert `param_index` into the correct list (e.g., in `arg_properties[prop0]`) """ - attrsDescriptor = AttrsDescriptor() - for prop_name, prop_val in attrsDescriptor.property_values.items(): - attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] - attrsDescriptor._init_slots() - return attrsDescriptor + attrs_descriptor = cls() + for prop_name, prop_val in attrs_descriptor.property_values.items(): + attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val] + attrs_descriptor._init_slots() + return attrs_descriptor @staticmethod def is_divisible_by_16(x): @@ -187,7 +211,7 @@ def get_property_key(val, align): return "N" def __repr__(self): - return f"AttrsDescriptor.from_dict({self.arg_properties})" + return f"AttrsDescriptor.from_dict({self.to_dict()!r})" @dataclass(frozen=True) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 0842849ad9..45178a40bb 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -879,6 +879,10 @@ def __init__(self, dtype): def data_ptr(): return 0 # optimistically assumes multiple of 16 + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + class TensorWrapper: diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 162695c2d9..390d1c83e6 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -1,4 +1,4 @@ -from triton.backends.compiler import BaseBackend, GPUTarget +from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor from triton._C.libtriton import ir, passes, llvm, amd from dataclasses import dataclass from typing import Any, Dict, Tuple @@ -72,6 +72,44 @@ def hash(self): return hashlib.sha256(key.encode("utf-8")).hexdigest() +@register_descriptor +class 
HIPAttrsDescriptor(AttrsDescriptor): + # This property asserts if the underlying storage area of a given pointer + # can be represented as a 32 bit integer. When this is true, we can be + # sure that all indices into the tensor behind that pointer can use 32-bit + # indexing. That opens the door for the AMD backend to use buffer load/store + # intrinsics, which requires this property. Buffer load/store intrinsics + # give direct out-of-bound support and simplifies index calculation for + # lower register pressure. + __slots__ = ("pointer_range_32") + + def _add_backend_properties(self, params=None, values=None): + self.property_values["tt.pointer_range"] = 32 + if params is None or values is None: + return + + self.arg_properties["tt.pointer_range"] = [ + param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg) + and not param.do_not_specialize and not param.do_not_specialize_on_alignment + ] + + @staticmethod + def is_within2gb(arg): + if hasattr(arg, "ptr_range"): + return arg.ptr_range() <= 2**31 - 1 + if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"): + # Please note that 2**31-1 is the max int32 positive limit + return arg.untyped_storage().size() <= 2**31 - 1 + return False + + @staticmethod + def get_property_key(val, align): + generic_key = AttrsDescriptor.get_property_key(val, align) + hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N" + key = (generic_key + hip_key).replace("N", "") + return key if key else "N" + + class HIPBackend(BaseBackend): @staticmethod @@ -118,6 +156,13 @@ def get_module_map(self) -> Dict[str, ModuleType]: def load_dialects(self, ctx): amd.load_dialects(ctx) + def get_attrs_descriptor(self, params, args): + return HIPAttrsDescriptor(params, args) + + @staticmethod + def compute_spec_key(arg, align): + return HIPAttrsDescriptor.get_property_key(arg, align) + @staticmethod def path_to_rocm_lld(): # Check env path for ld.lld