Skip to content

Commit 3288b24

Browse files
authored
Only use Tensor Descriptor indexing with appropriate shapes (#360)
1 parent 60eeb6a commit 3288b24

File tree

3 files changed

+132
-17
lines changed

3 files changed

+132
-17
lines changed

helion/_compiler/indexing_strategy.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import sympy
1010
import torch
11-
import triton
1211

1312
from .. import exc
1413
from .._compat import get_tensor_descriptor_fn_name
@@ -186,10 +185,23 @@ def is_supported(
186185
return False
187186

188187
def valid_block_size(
189-
block_size: int | torch.SymInt | None, stride: int | torch.SymInt
188+
block_size: int | torch.SymInt | None, stride: int | torch.SymInt, idx: int
190189
) -> bool:
191190
if not isinstance(block_size, int):
192191
return False
192+
193+
if (
194+
get_tensor_descriptor_fn_name()
195+
== "tl._experimental_make_tensor_descriptor"
196+
):
197+
# https://github.com/triton-lang/triton/blob/d654e0f2d91f07496454e0fcbec2a9b97df37d47/python/triton/language/semantic.py#L1162
198+
threshold = 32 // fake_tensor.dtype.itemsize
199+
if idx == 0:
200+
threshold = min(8, threshold)
201+
202+
if fake_tensor.ndim == 2 and block_size < threshold:
203+
return False
204+
193205
# was getting some IMAs (illegal memory accesses) with small block sizes even in non-stride-1 dims
194206
return block_size * element_size >= 16 or (block_size == 1 and stride != 1)
195207

@@ -198,34 +210,22 @@ def valid_block_size(
198210
strides = fake_tensor.stride()
199211
size_stride = collections.deque(zip(sizes, strides, strict=True))
200212
config = DeviceFunction.current().config
201-
for k in subscript:
213+
for i, k in enumerate(subscript):
202214
if k is None:
203215
continue
204216
size, stride = size_stride.popleft()
205217
if str(k) == "slice(None, None, None)":
206218
block_size = env.allocate_reduction_dimension(size).from_config(config)
207-
if not valid_block_size(block_size, stride):
219+
if not valid_block_size(block_size, stride, i):
208220
return False
209221
elif isinstance(k, torch.SymInt):
210222
block_id = env.get_block_id(k)
211223
if block_id is None:
212224
return False
213225
block_size = env.block_sizes[block_id].from_config(config)
214-
if not valid_block_size(block_size, stride):
226+
if not valid_block_size(block_size, stride, i):
215227
return False
216228

217-
# 5) Extra requirement for experimental version
218-
if get_tensor_descriptor_fn_name() == "tl._experimental_make_tensor_descriptor":
219-
# NOTE: There's no clean way to convert a torch.dtype to triton.dtype
220-
# This is improved in triton 3.4 but tl._experimental_make_tensor_descriptor
221-
# is only available on <= triton 3.3
222-
primitive_bitwidth = getattr(
223-
triton.language, str(fake_tensor.dtype).split(".")[-1]
224-
).primitive_bitwidth
225-
if env.size_hint(sizes[1]) < (32 // primitive_bitwidth) * 8:
226-
# https://github.com/triton-lang/triton/blob/d654e0f2d91f07496454e0fcbec2a9b97df37d47/python/triton/language/semantic.py#L1162
227-
return False
228-
229229
return True
230230

231231
def codegen_load(

test/test_indexing.expected

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,65 @@ def pairwise_add(x: torch.Tensor, *, _launcher=_default_launcher):
261261
_BLOCK_SIZE_0 = 32
262262
_launcher(_pairwise_add_kernel, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, x, out.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
263263
return out
264+
265+
--- assertExpectedJournal(TestIndexing.test_reduction_tensor_descriptor_indexing_block_size)
266+
from __future__ import annotations
267+
268+
import torch
269+
import triton
270+
import triton.language as tl
271+
from helion.runtime import default_launcher as _default_launcher
272+
273+
@triton.jit
274+
def _reduction_sum_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, m, _, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
275+
pid_0 = tl.program_id(0)
276+
offset_0 = pid_0 * _BLOCK_SIZE_0
277+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
278+
mask_0 = indices_0 < m
279+
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
280+
mask_1 = indices_1 < _
281+
load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
282+
sum_1 = tl.sum(load, 1)
283+
tl.store(out + indices_0 * out_stride_0, sum_1, mask_0)
284+
285+
def reduction_sum(x: torch.Tensor, *, _launcher=_default_launcher):
286+
m, _ = x.size()
287+
out = torch.empty([m], device=x.device, dtype=x.dtype)
288+
_BLOCK_SIZE_0 = 4
289+
_RDIM_SIZE_1 = triton.next_power_of_2(_)
290+
_launcher(_reduction_sum_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), x.stride(1), m, _, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
291+
return out
292+
293+
--- assertExpectedJournal(TestIndexing.test_reduction_tensor_descriptor_indexing_reduction_loop)
294+
from __future__ import annotations
295+
296+
import torch
297+
import triton
298+
import triton.language as tl
299+
from helion.runtime import default_launcher as _default_launcher
300+
301+
@triton.jit
302+
def _reduction_sum_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, m, _, _BLOCK_SIZE_0: tl.constexpr, _REDUCTION_BLOCK_1: tl.constexpr):
303+
pid_0 = tl.program_id(0)
304+
offset_0 = pid_0 * _BLOCK_SIZE_0
305+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
306+
mask_0 = indices_0 < m
307+
sum_1_acc = tl.full([_BLOCK_SIZE_0, _REDUCTION_BLOCK_1], 0, tl.float32)
308+
for roffset_1 in tl.range(0, _, _REDUCTION_BLOCK_1):
309+
rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
310+
mask_1 = rindex_1 < _
311+
load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
312+
v_0 = load.to(tl.float32)
313+
v_1 = sum_1_acc + v_0
314+
sum_1_acc = v_1
315+
sum_1 = tl.sum(sum_1_acc, 1)
316+
v_2 = sum_1.to(tl.float16)
317+
tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
318+
319+
def reduction_sum(x: torch.Tensor, *, _launcher=_default_launcher):
320+
m, _ = x.size()
321+
out = torch.empty([m], device=x.device, dtype=x.dtype)
322+
_BLOCK_SIZE_0 = 8
323+
_REDUCTION_BLOCK_1 = 8
324+
_launcher(_reduction_sum_kernel, (triton.cdiv(m, _BLOCK_SIZE_0),), x, out, out.stride(0), x.stride(0), x.stride(1), m, _, _BLOCK_SIZE_0, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
325+
return out

test/test_indexing.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,16 @@ def broadcast_add_3d(
2929
return out
3030

3131

32+
@helion.kernel
33+
def reduction_sum(x: torch.Tensor) -> torch.Tensor:
34+
m, _ = x.size()
35+
out = torch.empty([m], device=x.device, dtype=x.dtype)
36+
for tile in hl.tile(x.size(0)):
37+
out[tile] = x[tile, :].to(torch.float32).sum(-1).to(x.dtype)
38+
39+
return out
40+
41+
3242
class TestIndexing(TestCase):
3343
def test_arange(self):
3444
@helion.kernel
@@ -385,6 +395,49 @@ def test_broadcasting_tensor_descriptor_indexing(self):
385395
torch.testing.assert_close(result, expected)
386396
self.assertExpectedJournal(code)
387397

398+
@unittest.skipIf(not supports_tensor_descriptor(), "TensorDescriptor not supported")
399+
@unittest.skipIf(
400+
get_tensor_descriptor_fn_name() != "tl._experimental_make_tensor_descriptor",
401+
"Not using experimental tensor descriptor",
402+
)
403+
def test_reduction_tensor_descriptor_indexing_block_size(self):
404+
x = torch.randn([64, 64], dtype=torch.float32, device=DEVICE)
405+
406+
# Given block_size 4, tensor_descriptor should not actually be used
407+
# Convert to default pointer indexing
408+
code, result = code_and_output(
409+
reduction_sum,
410+
(x,),
411+
indexing="tensor_descriptor",
412+
block_size=[4],
413+
)
414+
415+
expected = torch.sum(x, dim=1)
416+
torch.testing.assert_close(result, expected)
417+
self.assertExpectedJournal(code)
418+
419+
@unittest.skipIf(not supports_tensor_descriptor(), "TensorDescriptor not supported")
420+
@unittest.skipIf(
421+
get_tensor_descriptor_fn_name() != "tl._experimental_make_tensor_descriptor",
422+
"Not using experimental tensor descriptor",
423+
)
424+
def test_reduction_tensor_descriptor_indexing_reduction_loop(self):
425+
x = torch.randn([64, 256], dtype=torch.float16, device=DEVICE)
426+
427+
# Given reduction_loops=[8], # of columns per block is not compatible with tensor_descriptor
428+
# Convert to default pointer indexing
429+
code, result = code_and_output(
430+
reduction_sum,
431+
(x,),
432+
indexing="tensor_descriptor",
433+
block_size=[8],
434+
reduction_loops=[8],
435+
)
436+
437+
expected = torch.sum(x, dim=1)
438+
torch.testing.assert_close(result, expected)
439+
self.assertExpectedJournal(code)
440+
388441

389442
if __name__ == "__main__":
390443
unittest.main()

0 commit comments

Comments
 (0)