Skip to content

Commit cee73e0

Browse files
authored
Add tl._experimental_make_tensor_descriptor restrictions (#331)
1 parent 7d01817 commit cee73e0

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88

99
import sympy
1010
import torch
11+
import triton
1112

1213
from .. import exc
14+
from .._compat import get_tensor_descriptor_fn_name
1315
from .ast_extension import expr_from_string
1416
from .compile_environment import CompileEnvironment
1517
from .device_function import DeviceFunction
@@ -192,9 +194,9 @@ def valid_block_size(
192194
return block_size * element_size >= 16 or (block_size == 1 and stride != 1)
193195

194196
# 4) Check minimum 16 bytes in each dimension
195-
size_stride = collections.deque(
196-
zip(fake_tensor.size(), fake_tensor.stride(), strict=True)
197-
)
197+
sizes = fake_tensor.size()
198+
strides = fake_tensor.stride()
199+
size_stride = collections.deque(zip(sizes, strides, strict=True))
198200
config = DeviceFunction.current().config
199201
for k in subscript:
200202
if k is None:
@@ -212,6 +214,18 @@ def valid_block_size(
212214
if not valid_block_size(block_size, stride):
213215
return False
214216

217+
# 5) Extra requirement for experimental version
218+
if get_tensor_descriptor_fn_name() == "tl._experimental_make_tensor_descriptor":
219+
# NOTE: There's no clean way to convert a torch.dtype to triton.dtype
220+
# This is improved in triton 3.4 but tl._experimental_make_tensor_descriptor
221+
# is only available on <= triton 3.3
222+
primitive_bitwidth = getattr(
223+
triton.language, str(fake_tensor.dtype).split(".")[-1]
224+
).primitive_bitwidth
225+
if env.size_hint(sizes[1]) < (32 // primitive_bitwidth) * 8:
226+
# https://github.com/triton-lang/triton/blob/d654e0f2d91f07496454e0fcbec2a9b97df37d47/python/triton/language/semantic.py#L1162
227+
return False
228+
215229
return True
216230

217231
def codegen_load(

0 commit comments

Comments
 (0)