Skip to content

Commit 2eefc41

Browse files
Copilot and mawad-amd committed
Fuse K+V put into single kernel; validate on AMD GPUs (16/16 tests pass)
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
1 parent 32c79d7 commit 2eefc41

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

examples/32_ring_attention/ring_attention_kernels.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,34 @@
1717

1818

1919
@triton.jit
def _put_kv_kernel(
    k_src,
    k_dst,
    v_src,
    v_dst,
    n_elem,
    cur_rank: tl.constexpr,
    next_rank: tl.constexpr,
    heap_bases,
    BLOCK: tl.constexpr,
):
    """
    Fused K+V put: copy K and V to the next rank in a single kernel launch.

    Both K and V tensors must be flat (same number of elements) and reside on
    the Iris symmetric heap so that their addresses can be translated to
    ``next_rank``'s address space.

    Each program instance copies ``BLOCK`` elements of K **and** ``BLOCK``
    elements of V, halving kernel-launch overhead compared to two separate
    single-tensor put launches.

    Args:
        k_src: Source K pointer (must be on the symmetric heap).
        k_dst: Destination K pointer (must be on the symmetric heap).
        v_src: Source V pointer (must be on the symmetric heap).
        v_dst: Destination V pointer (must be on the symmetric heap).
        n_elem: Total number of elements in K (same as V).
        cur_rank: This rank's ID.
        next_rank: Destination rank ID.
        heap_bases: Iris heap base address table.
    """
    # Each program handles one contiguous BLOCK-sized slice of the flat tensors.
    block_id = tl.program_id(0)
    offsets = block_id * BLOCK + tl.arange(0, BLOCK)
    # Guard the final partial block when n_elem is not a multiple of BLOCK.
    in_bounds = offsets < n_elem
    # Issue both transfers from the same program so K and V share one launch.
    iris.put(k_src + offsets, k_dst + offsets, cur_rank, next_rank, heap_bases, mask=in_bounds)
    iris.put(v_src + offsets, v_dst + offsets, cur_rank, next_rank, heap_bases, mask=in_bounds)
4958

5059

5160
@triton.jit
@@ -254,7 +263,6 @@ def ring_attn_fwd(q, k, v, shmem, causal=True, scale=None):
254263
# range of tensor sizes and GPU architectures.
255264
PUT_BLOCK = 1024
256265
n_k = k_cur.numel()
257-
n_v = v_cur.numel()
258266
heap_bases = shmem.get_heap_bases()
259267

260268
for step in range(world_size):
@@ -309,24 +317,19 @@ def ring_attn_fwd(q, k, v, shmem, causal=True, scale=None):
309317
num_stages=2,
310318
)
311319

312-
# Rotate K and V to the next rank using Iris put operations.
313-
# All ranks MUST participate in this step so the barrier is well-defined.
314-
# The ping-pong buffers guarantee that the source being read and the
315-
# destination being written are always different allocations.
320+
# Rotate K and V to the next rank using a single fused Iris put kernel.
321+
# Fusing K and V into one kernel launch halves launch overhead and lets
322+
# the GPU overlap their transfers. All ranks MUST participate in this
323+
# step so the barrier is well-defined. The ping-pong buffers guarantee
324+
# that the source being read and the destination being written are always
325+
# different allocations.
316326
if step < world_size - 1:
317-
_put_tensor_kernel[(triton.cdiv(n_k, PUT_BLOCK),)](
327+
_put_kv_kernel[(triton.cdiv(n_k, PUT_BLOCK),)](
318328
k_cur.view(-1),
319329
k_recv.view(-1),
320-
n_k,
321-
cur_rank=rank,
322-
next_rank=next_rank,
323-
heap_bases=heap_bases,
324-
BLOCK=PUT_BLOCK,
325-
)
326-
_put_tensor_kernel[(triton.cdiv(n_v, PUT_BLOCK),)](
327330
v_cur.view(-1),
328331
v_recv.view(-1),
329-
n_v,
332+
n_k,
330333
cur_rank=rank,
331334
next_rank=next_rank,
332335
heap_bases=heap_bases,

0 commit comments

Comments
 (0)