1717
1818
@triton.jit
def _put_kv_kernel(
    k_src,
    k_dst,
    v_src,
    v_dst,
    n_elem,
    cur_rank: tl.constexpr,
    next_rank: tl.constexpr,
    heap_bases,
    BLOCK: tl.constexpr,
):
    """
    Fused K+V put: copy K and V to the next rank in a single kernel launch.

    Both K and V tensors must be flat (same number of elements) and reside on
    the Iris symmetric heap so that their addresses can be translated to
    ``next_rank``'s address space.

    Each program instance copies ``BLOCK`` elements of K **and** ``BLOCK``
    elements of V, halving kernel-launch overhead compared to two separate
    ``_put_tensor_kernel`` calls.

    Args:
        k_src: Source K pointer (must be on the symmetric heap).
        k_dst: Destination K pointer (must be on the symmetric heap).
        v_src: Source V pointer (must be on the symmetric heap).
        v_dst: Destination V pointer (must be on the symmetric heap).
        n_elem: Total number of elements in K (same as V).
        cur_rank: This rank's ID.
        next_rank: Destination rank ID.
        heap_bases: Iris heap base address table.
        BLOCK: Number of elements each program instance copies per tensor.
    """
    # Each program handles one contiguous BLOCK-sized slice of the flat tensors.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Mask out the tail of the last block when n_elem is not a multiple of BLOCK.
    mask = offs < n_elem
    # Same offsets and mask apply to both tensors since K and V have n_elem
    # elements each; iris.put translates the destination address into
    # next_rank's address space via heap_bases.
    iris.put(k_src + offs, k_dst + offs, cur_rank, next_rank, heap_bases, mask=mask)
    iris.put(v_src + offs, v_dst + offs, cur_rank, next_rank, heap_bases, mask=mask)
4958
5059
5160@triton .jit
@@ -254,7 +263,6 @@ def ring_attn_fwd(q, k, v, shmem, causal=True, scale=None):
254263 # range of tensor sizes and GPU architectures.
255264 PUT_BLOCK = 1024
256265 n_k = k_cur .numel ()
257- n_v = v_cur .numel ()
258266 heap_bases = shmem .get_heap_bases ()
259267
260268 for step in range (world_size ):
@@ -309,24 +317,19 @@ def ring_attn_fwd(q, k, v, shmem, causal=True, scale=None):
309317 num_stages = 2 ,
310318 )
311319
312- # Rotate K and V to the next rank using Iris put operations.
313- # All ranks MUST participate in this step so the barrier is well-defined.
314- # The ping-pong buffers guarantee that the source being read and the
315- # destination being written are always different allocations.
320+ # Rotate K and V to the next rank using a single fused Iris put kernel.
321+ # Fusing K and V into one kernel launch halves launch overhead and lets
322+ # the GPU overlap their transfers. All ranks MUST participate in this
323+ # step so the barrier is well-defined. The ping-pong buffers guarantee
324+ # that the source being read and the destination being written are always
325+ # different allocations.
316326 if step < world_size - 1 :
317- _put_tensor_kernel [(triton .cdiv (n_k , PUT_BLOCK ),)](
327+ _put_kv_kernel [(triton .cdiv (n_k , PUT_BLOCK ),)](
318328 k_cur .view (- 1 ),
319329 k_recv .view (- 1 ),
320- n_k ,
321- cur_rank = rank ,
322- next_rank = next_rank ,
323- heap_bases = heap_bases ,
324- BLOCK = PUT_BLOCK ,
325- )
326- _put_tensor_kernel [(triton .cdiv (n_v , PUT_BLOCK ),)](
327330 v_cur .view (- 1 ),
328331 v_recv .view (- 1 ),
329- n_v ,
332+ n_k ,
330333 cur_rank = rank ,
331334 next_rank = next_rank ,
332335 heap_bases = heap_bases ,
0 commit comments