Commit b1f774c

[FRONTEND] Support for write-only ragged TMAs (#7792)
Tested on both H100 and GB200
1 parent: 198fd9b

2 files changed: +59, -16 lines
python/test/unit/cuda/test_tma_descriptor.py
Lines changed: 3 additions & 2 deletions

@@ -55,8 +55,9 @@ def example_load_store_kernel(X, Y, x_off, y_off, x_size, y_size):
     store_ragged(Y, y_off, y_size, [0, 0], data)
 
 
+@pytest.mark.parametrize("write_only", [False, True])
 @pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
-def test_ragged_tma(dtype):
+def test_ragged_tma(dtype, write_only):
 
     if not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 9:
         pytest.skip("Test requires Hopper or Blackwell target.")
@@ -69,7 +70,7 @@ def test_ragged_tma(dtype):
     dst = 1.0 * ref
 
     X = create_ragged_descriptor(src, [32, 128])
-    Y = create_ragged_descriptor(dst, [32, 128])
+    Y = create_ragged_descriptor(dst, [32, 128], write_only=write_only)
 
     x_off = 42
     y_off = 51
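
For reference, a host-side sketch of what the new parametrization exercises (hypothetical shapes and dtype, not part of the commit; a Hopper- or Blackwell-class GPU is assumed, as in the test itself):

    import torch
    from triton.tools.ragged_tma import create_ragged_descriptor

    src = torch.randn(1024, 128, device="cuda", dtype=torch.float16)
    dst = torch.zeros_like(src)

    X = create_ragged_descriptor(src, [32, 128])                   # read-write: two extra dims prepended
    Y = create_ragged_descriptor(dst, [32, 128], write_only=True)  # write-only: one extra dim, pre-shifted base pointer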

python/triton/tools/ragged_tma.py
Lines changed: 56 additions & 14 deletions

@@ -4,7 +4,16 @@
 
 # fmt: off
 
-def create_ragged_descriptor(T, block_shape):
+class TensorDescriptorPtr:
+    def __init__(self, data_ptr, dtype):
+        self._data_ptr = data_ptr
+        self.dtype = dtype
+
+    def data_ptr(self):
+        return self._data_ptr
+
+
+def create_ragged_descriptor(T, block_shape, ragged_dim=0, write_only=False):
     """
     Given a 2- or 3-dimensional tensor T, this creates a 'ragged descriptor'
     which behaves like a concatenation (along the first axis) of subarrays
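
The new TensorDescriptorPtr wrapper duck-types the two members that TensorDescriptor evidently needs from its first argument, data_ptr() and dtype, so that an arithmetically adjusted raw address can stand in for the tensor itself. A hypothetical illustration (the 4096-byte shift is arbitrary):

    wrapped = TensorDescriptorPtr(T.data_ptr() - 4096, T.dtype)  # shift chosen arbitrarily
    assert wrapped.data_ptr() == T.data_ptr() - 4096
    assert wrapped.dtype == T.dtype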
@@ -18,22 +27,41 @@ def create_ragged_descriptor(T, block_shape):
 
     block_shape = list(block_shape)
     tensor_shape = list(T.shape)
+    rank = len(tensor_shape)
+
+    if ragged_dim < 0:
+        ragged_dim += rank
 
-    assert 2 <= len(tensor_shape) <= 3, "ragged tensors must have dimension 2 or 3"
-    assert len(tensor_shape) == len(block_shape), "block shape must match tensor shape"
+    assert 0 <= ragged_dim < rank - 1, "last dimension cannot be ragged"
+
+    if write_only:
+        assert rank <= 4, "write-only ragged descriptors must have at most 4 dimensions"
+    else:
+        assert rank <= 3, "read-write ragged descriptors must have at most 3 dimensions"
+
+    assert len(block_shape) == rank, "block shape must have same length as tensor shape"
 
     max_int = 0x7fff0000
     billion = 0x40000000  # == 2**30
 
-    assert tensor_shape[0] <= billion, "number of rows may not exceed 2**30"
+    assert tensor_shape[ragged_dim] <= billion, "number of rows may not exceed 2**30"
+    tensor_shape[ragged_dim] = billion
+    ragged_stride = T.stride(ragged_dim)
 
     # we prepend an extra two dimensions and rely on the fact that pointers
     # have 64-bit wraparound semantics:
-    tma_stride = [2**34 - T.stride(0), T.stride(0)] + [T.stride(i) for i in range(len(tensor_shape))]
-    tma_shape = [max_int, max_int, billion] + tensor_shape[1:]
+    tma_stride = [2**34 - ragged_stride, ragged_stride] + [T.stride(i) for i in range(rank)]
+    tma_shape = [max_int, max_int] + tensor_shape
     box_shape = [1, 1] + block_shape
+    ptr = T.data_ptr()
 
-    return TensorDescriptor(T, tma_shape, tma_stride, box_shape)
+    if write_only:
+        tma_stride = tma_stride[1:]
+        tma_shape = tma_shape[1:]
+        box_shape = box_shape[1:]
+        ptr = (ptr - billion * ragged_stride * T.element_size()) % (2**64)
+
+    return TensorDescriptor(TensorDescriptorPtr(ptr, T.dtype), tma_shape, tma_stride, box_shape)
 
 
 @triton.jit
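
Why write_only can drop a dimension: the leading index of the read-write layout exists only to cancel the address shift introduced by the second dimension. Assuming, as the strides above suggest, that the leading index is always the constant 2**30, its byte offset is itself a constant and can be folded into the base pointer, which is exactly what the write_only branch does. A plain-Python sanity check of that wraparound identity (hypothetical stride and element size, not part of the commit):

    billion = 0x40000000              # 2**30
    ragged_stride = 512               # hypothetical stride along the ragged dim, in elements
    element_size = 2                  # hypothetical: float16

    # byte contribution of index 2**30 along the dropped dimension, whose stride
    # is 2**34 - ragged_stride; the 2**30 * 2**34 == 2**64 term wraps to zero:
    contribution = billion * (2**34 - ragged_stride) * element_size % 2**64
    assert contribution == (-billion * ragged_stride * element_size) % 2**64

    # hence pre-shifting the base pointer, as the write_only branch does,
    # reproduces the same 64-bit addresses with one fewer dimension:
    base = 0x7F0000000000             # hypothetical device pointer
    assert (base - billion * ragged_stride * element_size) % 2**64 == (base + contribution) % 2**64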
@@ -50,7 +78,7 @@ def to_ragged_indices(batch_offset, batch_size, row):
 
 
 @triton.jit
-def load_ragged(TMA, batch_offset, batch_size, coords):
+def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
     """
     Read from a subarray T[batch_offset : batch_offset + batch_size] with
     hardware bounds-checking, where reading outside the subarray gives zeros.
@@ -59,14 +87,16 @@ def load_ragged(TMA, batch_offset, batch_size, coords):
     TMA.load().
     """
 
-    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
-    data = TMA.load([c0, c1, c2] + coords[1:])
+    tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
     data = tl.reshape(data, data.shape[2:])
     return data
 
 
 @triton.jit
-def store_ragged(TMA, batch_offset, batch_size, coords, data):
+def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
     """
     Write to a subarray T[batch_offset : batch_offset + batch_size] with
     hardware bounds-checking, where writes outside the subarray are masked
@@ -76,6 +106,18 @@ def store_ragged(TMA, batch_offset, batch_size, coords, data):
     TMA.store().
     """
 
-    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[0])
-    data = tl.reshape(data, [1, 1] + data.shape)
-    TMA.store([c0, c1, c2] + coords[1:], data)
+    if len(TMA.shape) == len(coords) + 1:
+        write_only: tl.constexpr = True
+    elif len(TMA.shape) == len(coords) + 2:
+        write_only: tl.constexpr = False
+    else:
+        tl.static_assert(False, "TMA must be a ragged descriptor")
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+
+    if write_only:
+        data = tl.reshape(data, [1] + data.shape)
+        TMA.store([c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
+    else:
+        data = tl.reshape(data, [1, 1] + data.shape)
+        TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
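
Device-side usage is unchanged for callers: load_ragged still requires a read-write descriptor, while store_ragged now dispatches at compile time on the descriptor's rank, so the same kernel accepts either kind of output descriptor. A minimal sketch modelled on the test's example_load_store_kernel (kernel name hypothetical; assumes 2D descriptors built as above):

    import triton
    import triton.language as tl
    from triton.tools.ragged_tma import load_ragged, store_ragged

    @triton.jit
    def copy_ragged_kernel(X, Y, x_off, y_off, x_size, y_size):
        # reads outside X[x_off : x_off + x_size] come back as zeros:
        data = load_ragged(X, x_off, x_size, [0, 0])
        # writes outside Y[y_off : y_off + y_size] are masked; Y may be
        # read-write (rank 4 here) or write-only (rank 3), detected statically:
        store_ragged(Y, y_off, y_size, [0, 0], data)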
