
Commit 0d44a36

[FRONTEND] Ragged TMA atomic add (#8238)
Adds ragged TMA atomic add support to Triton.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 905b3d1 commit 0d44a36

File tree

2 files changed: 34 additions & 5 deletions


python/test/unit/cuda/test_tma_descriptor.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -2,7 +2,7 @@
 import pytest
 import torch
 import triton
-from triton.tools.ragged_tma import create_ragged_descriptor, load_ragged, store_ragged
+from triton.tools.ragged_tma import create_ragged_descriptor, atomic_add_ragged, load_ragged, store_ragged
 from triton.tools.tensor_descriptor import TensorDescriptor


@@ -55,6 +55,13 @@ def example_load_store_kernel(X, Y, x_off, y_off, x_size, y_size):
     store_ragged(Y, y_off, y_size, [0, 0], data)


+@triton.jit
+def example_load_atomic_add_kernel(X, Y, x_off, y_off, x_size, y_size):
+
+    data = load_ragged(X, x_off, x_size, [0, 0])
+    atomic_add_ragged(Y, y_off, y_size, [0, 0], data)
+
+
 @pytest.mark.parametrize("dtype", [
     "bfloat16", "float16", "float32", "float64",  # floating-point
     "int8", "int16", "int32", "int64",  # signed integers
@@ -66,28 +73,34 @@ def test_ragged_tma(dtype):
         pytest.skip("Test requires Hopper or Blackwell target.")
         return

+    test_atomic_add = dtype in ["bfloat16", "float16", "float32", "int32"]
     dtype = getattr(torch, dtype)

-    src = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
+    src1 = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
+    src2 = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
     ref = torch.randn((1024, 80), dtype=torch.float32, device="cuda").to(dtype)
     dst = ref.clone()

-    X = create_ragged_descriptor(src, [32, 128])
+    X1 = create_ragged_descriptor(src1, [32, 128])
+    X2 = create_ragged_descriptor(src2, [32, 128])
     Y = create_ragged_descriptor(dst, [32, 128])

     x_off = 42
     y_off = 51
     x_size = 17
     y_size = 24

-    example_load_store_kernel[(1, )](X, Y, x_off, y_off, x_size, y_size)
+    example_load_store_kernel[(1, )](X1, Y, x_off, y_off, x_size, y_size)
+    if test_atomic_add:
+        example_load_atomic_add_kernel[(1, )](X2, Y, x_off, y_off, x_size, y_size)

     # the initial and final segments are unchanged:
     res0 = torch.equal(dst[:y_off], ref[:y_off])
     res1 = torch.equal(dst[y_off + y_size:], ref[y_off + y_size:])

     # this segment will be copied verbatim from src:
-    res2 = torch.equal(dst[y_off:y_off + x_size], src[x_off:x_off + x_size])
+    ref_tensor = src1 + src2 if test_atomic_add else src1
+    res2 = torch.equal(dst[y_off:y_off + x_size], ref_tensor[x_off:x_off + x_size])

     # this segment will have read OOB zeroes and written them here:
     res3 = torch.all(dst[y_off + x_size:y_off + y_size] == 0.0).item()
```
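
In plain PyTorch terms, the four `res` checks amount to the following restatement of what `dst` should contain after the launches. This is a paraphrase of the test's assertions, not part of the patch; all names (`src1`, `src2`, `ref`, `dst`, the offsets) are the ones defined in the test above:

```python
# Paraphrase of the test's assertions. When test_atomic_add holds, the middle
# segment receives the sum of both sources; otherwise it is a plain copy of src1.
expected = ref.clone()

contrib = src1 + src2 if test_atomic_add else src1
expected[y_off:y_off + x_size] = contrib[x_off:x_off + x_size]

# the ragged loads read zeroes past x_size, and those zeroes are stored
# (and atomically added) into this tail of the destination segment
expected[y_off + x_size:y_off + y_size] = 0

# rows outside [y_off, y_off + y_size) keep their original ref values
assert torch.equal(dst, expected)
```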

python/triton/tools/ragged_tma.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -90,3 +90,19 @@ def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
     c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
     data = tl.reshape(data, [1, 1] + data.shape)
     TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
+
+
+@triton.jit
+def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
+    """
+    Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with
+    hardware bounds-checking, where adds outside the subarray are masked
+    correctly.
+
+    Coords should be an appropriately-sized list of integers, just like in
+    TMA.atomic_add().
+    """
+
+    c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
+    data = tl.reshape(data, [1, 1] + data.shape)
+    TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
```
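
Combined with `load_ragged`, the new helper gives an in-kernel ragged accumulate. Below is a minimal usage sketch, assuming a Hopper or Blackwell GPU; the kernel name, tensor shapes, and offsets are illustrative rather than part of the patch, and float32 is chosen because the test above only exercises atomic add for bfloat16/float16/float32/int32:

```python
import torch
import triton
from triton.tools.ragged_tma import (atomic_add_ragged, create_ragged_descriptor,
                                     load_ragged)


@triton.jit
def ragged_accumulate_kernel(X, Y, x_off, y_off, x_size, y_size):
    # read a [32, 128] tile from the ragged source; rows past x_size come back as zero
    data = load_ragged(X, x_off, x_size, [0, 0])
    # atomically add the tile into the ragged destination; rows past y_size are masked
    atomic_add_ragged(Y, y_off, y_size, [0, 0], data)


src = torch.randn((1024, 80), dtype=torch.float32, device="cuda")
acc = torch.zeros((1024, 80), dtype=torch.float32, device="cuda")

X = create_ragged_descriptor(src, [32, 128])  # [32, 128] is the tile shape per call
Y = create_ragged_descriptor(acc, [32, 128])

# accumulate a 17-row ragged segment of src into a 24-row segment of acc
ragged_accumulate_kernel[(1, )](X, Y, 42, 51, 17, 24)
```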
