
Commit 54b3834

Revert "Revert "[AMD] Add a tt.pointer_range_32 specialization (#4910)""
This reverts commit 7f80413.
1 parent dd36f6d commit 54b3834

4 files changed (+115, −15 lines)

python/test/unit/runtime/test_cache.py

Lines changed: 27 additions & 0 deletions
@@ -9,6 +9,7 @@
 import triton
 import triton.language as tl
 from triton.runtime.jit import JITFunction
+from triton._internal_testing import is_hip


 @triton.jit
@@ -572,3 +573,29 @@ def compiled_hook(*args, **kwargs):
     assert specialization_data is not None and specialization_data_compiled == specialization_data
     assert is_warmup is True
     assert key in kernel_add.cache[getattr(torch, device).current_device()]
+
+
+@pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
+def test_within_2gb(device, fresh_triton_cache) -> None:
+
+    @triton.jit
+    def kernel_add(a):
+        tl.load(a)
+
+    # This is the attribute we want to test
+    pointer_range_32 = None
+
+    def cache_hook(*args, **kwargs):
+        nonlocal pointer_range_32
+        pointer_range_32 = kwargs["compile"]["configs"][0].pointer_range_32
+
+    JITFunction.cache_hook = cache_hook
+    # In warmup we assume that the pointer range is 32 bits
+    kernel_add.warmup(torch.float32, grid=(1, ))
+    assert pointer_range_32 == [0]
+    # Torch tensor > 2GB
+    kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
+    assert len(pointer_range_32) == 0
+    # Torch tensor <= 2GB
+    kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
+    assert pointer_range_32 == [0]
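For reference, the boundary the test probes is the signed 32-bit byte limit. A minimal sketch of the arithmetic, outside the test suite and with illustrative names (assuming one byte per int8 element):

# Sketch of the size boundary exercised by test_within_2gb above.
INT32_MAX = 2**31 - 1        # largest byte offset a signed 32-bit index can address

big_numel = 2**31            # 2 GiB of int8: one byte past the limit, so no specialization
small_numel = 2**31 - 1      # exactly at the limit, so the pointer argument is specialized

assert big_numel > INT32_MAX
assert small_numel <= INT32_MAX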

python/triton/backends/compiler.py

Lines changed: 38 additions & 14 deletions
@@ -8,7 +8,21 @@
 from typing import Dict, Union
 from types import ModuleType

+# Table that associates strings to AttrsDescriptor (sub)classes.
+# In this way we can dynamically select the correct class
+# constructor
+_descriptor_table = {}

+
+def register_descriptor(cls):
+    """
+    Register a descriptor into the descriptor table
+    """
+    _descriptor_table[cls.__name__] = cls
+    return cls
+
+
+@register_descriptor
 class AttrsDescriptor:
     """
     This class handles compile-time properties for specific function parameters.
@@ -135,18 +149,28 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()

     def to_dict(self):
-        return self.arg_properties
+        """
+        Store the fields of this class in a serializable dictionary
+        """
+        # We only need to store the `arg_properties` field. To initialize the
+        # other fields we rely on the class type. We store it as a string in
+        # the dictionary so that we can use it to invoke the appropriate
+        # (sub)class constructor in the `from_dict` method.
+        return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

     @staticmethod
     def from_dict(data):
-        attrsDescriptor = AttrsDescriptor()
-        for prop_name, param_ids in data.items():
-            attrsDescriptor.arg_properties[prop_name] = param_ids
-        attrsDescriptor._init_slots()
-        return attrsDescriptor
-
-    @staticmethod
-    def from_hints(hints: list[tuple[int, int]]):
+        """
+        Create the object from a serializable dictionary
+        """
+        attrs_descriptor = _descriptor_table[data["cls"]]()
+        for prop_name, param_ids in data["arg_properties"].items():
+            attrs_descriptor.arg_properties[prop_name] = param_ids
+        attrs_descriptor._init_slots()
+        return attrs_descriptor
+
+    @classmethod
+    def from_hints(cls, hints: list[tuple[int, int]]):
         """
         Create the class from a set of hints that are passed in.

@@ -156,11 +180,11 @@ def from_hints(hints: list[tuple[int, int]]):
         then we insert `param_index` into the correct list (e.g., in
         `arg_properties[prop0]`)
         """
-        attrsDescriptor = AttrsDescriptor()
-        for prop_name, prop_val in attrsDescriptor.property_values.items():
-            attrsDescriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
-        attrsDescriptor._init_slots()
-        return attrsDescriptor
+        attrs_descriptor = cls()
+        for prop_name, prop_val in attrs_descriptor.property_values.items():
+            attrs_descriptor.arg_properties[prop_name] = [i for i, h in hints.items() if h == prop_val]
+        attrs_descriptor._init_slots()
+        return attrs_descriptor

     @staticmethod
     def is_divisible_by_16(x):
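To make the new serialization path concrete, here is a small standalone sketch of the registry pattern. ToyDescriptor and ToyHIPDescriptor are illustrative stand-ins rather than the Triton classes, and they carry only the arg_properties field:

# Standalone sketch: a registry lets from_dict rebuild the right subclass.
_descriptor_table = {}


def register_descriptor(cls):
    _descriptor_table[cls.__name__] = cls
    return cls


@register_descriptor
class ToyDescriptor:

    def __init__(self):
        self.arg_properties = {}

    def to_dict(self):
        # Store the class name so from_dict can pick the right constructor.
        return {"arg_properties": self.arg_properties, "cls": type(self).__name__}

    @staticmethod
    def from_dict(data):
        obj = _descriptor_table[data["cls"]]()
        obj.arg_properties = dict(data["arg_properties"])
        return obj


@register_descriptor
class ToyHIPDescriptor(ToyDescriptor):
    pass


d = ToyHIPDescriptor()
d.arg_properties["tt.pointer_range"] = [0]
restored = ToyDescriptor.from_dict(d.to_dict())
assert type(restored) is ToyHIPDescriptor               # the subclass is recovered from "cls"
assert restored.arg_properties == {"tt.pointer_range": [0]}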

python/triton/runtime/jit.py

Lines changed: 4 additions & 0 deletions
@@ -879,6 +879,10 @@ def __init__(self, dtype):
     def data_ptr():
         return 0  # optimistically assumes multiple of 16

+    @staticmethod
+    def ptr_range():
+        return 0  # optimistically assumes 32 bit pointer range
+

 class TensorWrapper:
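A toy illustration (not Triton's MockTensor) of why warmup arguments are optimistically treated as within 2 GB; is_within2gb here only mirrors the ptr_range branch of the HIPAttrsDescriptor check shown in the next file:

# Stand-in for the warmup-time tensor: it reports a zero pointer range.
class FakeTensor:

    @staticmethod
    def ptr_range():
        return 0  # optimistic: assume the 32-bit pointer range holds


def is_within2gb(arg):
    # Mirrors the ptr_range branch of HIPAttrsDescriptor.is_within2gb below.
    if hasattr(arg, "ptr_range"):
        return arg.ptr_range() <= 2**31 - 1
    return False


assert is_within2gb(FakeTensor())  # warmup arguments get the tt.pointer_range specialization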

third_party/amd/backend/compiler.py

Lines changed: 46 additions & 1 deletion
@@ -1,4 +1,4 @@
-from triton.backends.compiler import BaseBackend, GPUTarget
+from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
 from triton._C.libtriton import ir, passes, llvm, amd
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
@@ -72,6 +72,44 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()


+@register_descriptor
+class HIPAttrsDescriptor(AttrsDescriptor):
+    # This property indicates whether the underlying storage area of a given
+    # pointer can be represented as a 32-bit integer. When this is true, we can
+    # be sure that all indices into the tensor behind that pointer can use
+    # 32-bit indexing. That opens the door for the AMD backend to use buffer
+    # load/store intrinsics, which require this property. Buffer load/store
+    # intrinsics give direct out-of-bounds support and simplify index
+    # calculation for lower register pressure.
+    __slots__ = ("pointer_range_32")
+
+    def _add_backend_properties(self, params=None, values=None):
+        self.property_values["tt.pointer_range"] = 32
+        if params is None or values is None:
+            return
+
+        self.arg_properties["tt.pointer_range"] = [
+            param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
+            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
+        ]
+
+    @staticmethod
+    def is_within2gb(arg):
+        if hasattr(arg, "ptr_range"):
+            return arg.ptr_range() <= 2**31 - 1
+        if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"):
+            # Please note that 2**31 - 1 is the max positive int32 value
+            return arg.untyped_storage().size() <= 2**31 - 1
+        return False
+
+    @staticmethod
+    def get_property_key(val, align):
+        generic_key = AttrsDescriptor.get_property_key(val, align)
+        hip_key = "S" if HIPAttrsDescriptor.is_within2gb(val) else "N"
+        key = (generic_key + hip_key).replace("N", "")
+        return key if key else "N"
+
+
 class HIPBackend(BaseBackend):

     @staticmethod
@@ -118,6 +156,13 @@ def get_module_map(self) -> Dict[str, ModuleType]:
     def load_dialects(self, ctx):
         amd.load_dialects(ctx)

+    def get_attrs_descriptor(self, params, args):
+        return HIPAttrsDescriptor(params, args)
+
+    @staticmethod
+    def compute_spec_key(arg, align):
+        return HIPAttrsDescriptor.get_property_key(arg, align)
+
     @staticmethod
     def path_to_rocm_lld():
         # Check env path for ld.lld
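As a rough sketch of how the HIP specialization key composes with the generic one: only the composition logic below mirrors get_property_key; the generic letter "D" (divisible by 16) is assumed here for illustration and is not defined in this diff.

# Sketch of the key composition in HIPAttrsDescriptor.get_property_key.
def compose_hip_key(generic_key, within_2gb):
    hip_key = "S" if within_2gb else "N"
    key = (generic_key + hip_key).replace("N", "")
    return key if key else "N"


assert compose_hip_key("D", True) == "DS"   # aligned and within 2 GB
assert compose_hip_key("D", False) == "D"   # aligned only
assert compose_hip_key("N", True) == "S"    # within 2 GB only
assert compose_hip_key("N", False) == "N"   # neither: fall back to the "N" placeholder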
