 import sympy
 import torch
+import triton

 from .. import exc
+from .._compat import get_tensor_descriptor_fn_name
 from .ast_extension import expr_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
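
The new `_compat` import gates the extra check added below on which tensor-descriptor API the installed triton build exposes. As a rough sketch only (the real helper lives in helion's `_compat` module and its implementation is not shown in this diff), such a shim might dispatch on the presence of the stable API:

```python
# Hypothetical sketch of a version shim like get_tensor_descriptor_fn_name.
# Assumption: triton >= 3.4 exposes tl.make_tensor_descriptor, while
# triton <= 3.3 only ships the experimental variant.
import triton.language as tl

def get_tensor_descriptor_fn_name() -> str:
    if hasattr(tl, "make_tensor_descriptor"):
        return "tl.make_tensor_descriptor"
    return "tl._experimental_make_tensor_descriptor"
```
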
@@ -192,9 +194,9 @@ def valid_block_size(
             return block_size * element_size >= 16 or (block_size == 1 and stride != 1)

         # 4) Check minimum 16 bytes in each dimension
-        size_stride = collections.deque(
-            zip(fake_tensor.size(), fake_tensor.stride(), strict=True)
-        )
+        sizes = fake_tensor.size()
+        strides = fake_tensor.stride()
+        size_stride = collections.deque(zip(sizes, strides, strict=True))
         config = DeviceFunction.current().config
         for k in subscript:
             if k is None:
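
For context, the `valid_block_size` rule above (unchanged by this diff) requires each tiled dimension's block to span at least 16 bytes, with an escape hatch for size-1 blocks in non-contiguous dimensions. A small self-contained illustration, assuming 4-byte `float32` elements:

```python
# Worked example of the 16-byte-per-dimension rule from check 4.
element_size = 4  # bytes per element, e.g. torch.float32

def valid_block_size(block_size: int, stride: int) -> bool:
    return block_size * element_size >= 16 or (block_size == 1 and stride != 1)

assert valid_block_size(4, 1)      # 4 elements * 4 bytes = 16 bytes: OK
assert not valid_block_size(2, 1)  # 8 bytes < 16 bytes: rejected
assert valid_block_size(1, 128)    # size-1 block in a strided (non-contiguous) dim: OK
```
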
@@ -212,6 +214,18 @@ def valid_block_size(
             if not valid_block_size(block_size, stride):
                 return False

+        # 5) Extra requirement for the experimental version
+        if get_tensor_descriptor_fn_name() == "tl._experimental_make_tensor_descriptor":
+            # NOTE: there is no clean way to convert a torch.dtype to a triton dtype.
+            # This improves in triton 3.4, but tl._experimental_make_tensor_descriptor
+            # is only available on triton <= 3.3.
+            primitive_bitwidth = getattr(
+                triton.language, str(fake_tensor.dtype).split(".")[-1]
+            ).primitive_bitwidth
+            if env.size_hint(sizes[1]) < (32 // primitive_bitwidth) * 8:
+                # https://github.com/triton-lang/triton/blob/d654e0f2d91f07496454e0fcbec2a9b97df37d47/python/triton/language/semantic.py#L1162
+                return False
+
         return True

     def codegen_load(
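
The `getattr` trick in check 5 works because torch and triton use the same dtype names: `str(torch.float16)` is `"torch.float16"`, and the suffix `"float16"` is also the attribute name of the matching dtype on `triton.language`. A sketch of the resulting minimum, assuming a triton install that provides `primitive_bitwidth` on its dtypes; note that `(32 // bitwidth) * 8` elements works out to a 32-byte floor for the last dimension whenever the bit width divides 32:

```python
# Sketch: resolve a torch.dtype to triton's bit width and compute the
# minimum last-dimension size the experimental descriptor path accepts.
import torch
import triton.language as tl

def min_last_dim(dtype: torch.dtype) -> int:
    bitwidth = getattr(tl, str(dtype).split(".")[-1]).primitive_bitwidth
    return (32 // bitwidth) * 8

assert min_last_dim(torch.float32) == 8   # 8 * 4 bytes = 32 bytes
assert min_last_dim(torch.float16) == 16  # 16 * 2 bytes = 32 bytes
```
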