@@ -138,8 +138,8 @@ def persistent_all_reduce_atomic(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -162,8 +162,8 @@ def persistent_all_reduce_atomic(
         M: Number of rows
         N: Number of columns
         heap_bases: Heap base pointers for all ranks
-        cur_rank: Current rank within the group (for comparisons)
-        cur_rank_global: Global rank (for iris IPC operations)
+        group_rank: Rank within the ProcessGroup (0 to group_size-1), used for tile assignment and comparisons
+        iris_rank: Rank in the iris context, used for iris RMA operations (heap_bases indexing)
         world_size: Total number of ranks in the group
     """
     pid = tl.program_id(0)
@@ -210,21 +210,21 @@ def persistent_all_reduce_atomic(
         data = tl.load(input_ptr_local, mask=mask)

         # Atomically add to output buffer on all ranks
-        # Each rank's output tensor is in its own heap, accessible via IPC
+        # Each rank's output tensor is in its own heap, accessible via RMA
         for i in range(world_size):
             target_rank = rank_start + i * rank_stride
-            if i == cur_rank:
-                # For the current rank (i == rank_in_group), use local atomic add
+            if i == group_rank:
+                # For the current rank (i == group_rank), use local atomic add
                 # output_ptr is already in current rank's address space
                 tl.atomic_add(output_ptr + output_offset, data, mask=mask)
             else:
                 # For remote ranks, use iris.atomic_add to translate pointer
-                # This accesses the remote rank's heap via IPC
-                # Use cur_rank_global for iris operations (heap_bases indexing)
+                # This accesses the remote rank's heap via RMA
+                # Use iris_rank for iris operations (heap_bases indexing)
                 iris.atomic_add(
                     output_ptr + output_offset,
                     data,
-                    cur_rank_global,
+                    iris_rank,
                     target_rank,
                     heap_bases,
                     mask=mask,
@@ -245,8 +245,8 @@ def persistent_all_reduce_spinlock(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -310,7 +310,7 @@ def persistent_all_reduce_spinlock(
             remote_rank = rank_start + i * rank_stride
             partial = iris.load(
                 input_ptr + input_offset,
-                cur_rank_global,
+                iris_rank,
                 remote_rank,
                 heap_bases,
                 mask=mask,
@@ -333,8 +333,8 @@ def persistent_all_reduce_one_shot(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -389,7 +389,7 @@ def persistent_all_reduce_one_shot(
             remote_rank = rank_start + i * rank_stride
             partial = iris.load(
                 input_ptr + input_offset,
-                cur_rank_global,
+                iris_rank,
                 remote_rank,
                 heap_bases,
                 mask=mask,
@@ -416,8 +416,8 @@ def persistent_all_reduce_ring(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -504,7 +504,7 @@ def persistent_all_reduce_ring(
                 remote_flag_ptr,
                 0,
                 0,
-                cur_rank_global,
+                iris_rank,
                 next_rank,
                 heap_bases,
                 sem="acquire",
@@ -517,7 +517,7 @@ def persistent_all_reduce_ring(
             iris.store(
                 ring_buffer + tile_offset,
                 send_data,
-                cur_rank_global,
+                iris_rank,
                 next_rank,
                 heap_bases,
                 mask=mask,
@@ -526,7 +526,7 @@ def persistent_all_reduce_ring(
             iris.atomic_xchg(
                 remote_flag_ptr,
                 1,
-                cur_rank_global,
+                iris_rank,
                 next_rank,
                 heap_bases,
                 sem="release",
@@ -560,8 +560,8 @@ def persistent_all_reduce_two_shot(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -586,13 +586,13 @@ def persistent_all_reduce_two_shot(

     tiles_per_rank = tl.cdiv(total_tiles, world_size)
     if DISTRIBUTION == 0:
-        start_tile = cur_rank
+        start_tile = group_rank
         stride = world_size
         remaining = total_tiles - start_tile
         remaining = tl.maximum(remaining, 0)
         max_tile_offset = tl.cdiv(remaining, stride)
     else:
-        start_tile = cur_rank * tiles_per_rank
+        start_tile = group_rank * tiles_per_rank
         stride = 1
         remaining = total_tiles - start_tile
         remaining = tl.maximum(remaining, 0)
@@ -636,11 +636,11 @@ def persistent_all_reduce_two_shot(

         start_rank_idx = pid % world_size
         start_rank_global = rank_start + start_rank_idx * rank_stride
-        acc = iris.load(base_ptr, cur_rank_global, start_rank_global, heap_bases).to(acc_dtype)
+        acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases).to(acc_dtype)
         for i in tl.static_range(1, world_size):
             remote_rank_idx = (start_rank_idx + i) % world_size
             remote_rank = rank_start + remote_rank_idx * rank_stride
-            acc += iris.load(base_ptr, cur_rank_global, remote_rank, heap_bases).to(acc_dtype)
+            acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases).to(acc_dtype)

         reduced = acc.to(output_ptr.type.element_ty)

@@ -649,8 +649,8 @@ def persistent_all_reduce_two_shot(
         for i in tl.static_range(0, world_size):
             remote_rank_idx = (start_rank_idx + i) % world_size
             remote_rank = rank_start + remote_rank_idx * rank_stride
-            if remote_rank_idx != cur_rank:
-                iris.store(out_ptr, reduced, cur_rank_global, remote_rank, heap_bases)
+            if remote_rank_idx != group_rank:
+                iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases)

     # Slow path: MASKED (only boundary tiles land here)
     # This path handles tiles at tensor boundaries where not all elements are valid.
@@ -659,11 +659,11 @@ def persistent_all_reduce_two_shot(

         start_rank_idx = pid % world_size
         start_rank_global = rank_start + start_rank_idx * rank_stride
-        acc = iris.load(base_ptr, cur_rank_global, start_rank_global, heap_bases, mask=mask).to(acc_dtype)
+        acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases, mask=mask).to(acc_dtype)
         for i in tl.static_range(1, world_size):
             remote_rank_idx = (start_rank_idx + i) % world_size
             remote_rank = rank_start + remote_rank_idx * rank_stride
-            acc += iris.load(base_ptr, cur_rank_global, remote_rank, heap_bases, mask=mask).to(acc_dtype)
+            acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype)

         reduced = acc.to(output_ptr.type.element_ty)

@@ -672,8 +672,8 @@ def persistent_all_reduce_two_shot(
         for i in tl.static_range(0, world_size):
             remote_rank_idx = (start_rank_idx + i) % world_size
             remote_rank = rank_start + remote_rank_idx * rank_stride
-            if remote_rank_idx != cur_rank:
-                iris.store(out_ptr, reduced, cur_rank_global, remote_rank, heap_bases, mask=mask)
+            if remote_rank_idx != group_rank:
+                iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, mask=mask)


 def all_reduce(
@@ -729,8 +729,8 @@ def all_reduce(
     )

     # Extract group information
-    # rank_in_group: position within the group (0, 1, 2, ...) - used for tile assignment and comparisons
-    # rank_global: global rank across all processes - used for iris IPC operations
+    # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel
+    # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)
     M, N = input_tensor.shape[:2]

@@ -843,7 +843,7 @@ def all_reduce(
     )

     # Calculate next rank in the ring for group support
-    # next_rank must be a global rank for iris IPC operations
+    # next_rank must be a global rank for iris RMA operations
     if group is None:
         # Simple case: next rank is just (rank_in_group + 1) % world_size (which equals global rank)
         next_rank = (rank_in_group + 1) % world_size
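
For reference, a minimal host-side sketch of how the renamed parameters line up after this change. The kernel takes both ranks: group_rank (position within the ProcessGroup) drives tile assignment and the local-vs-remote comparison, while iris_rank (rank in the iris context) indexes heap_bases for RMA. The launch below is illustrative only: the grid size, positional-argument order, heap_bases source, and any trailing tile-size constexprs (elided here) are assumptions, not the actual launch code from this file.

    # Hypothetical launch sketch; only the parameter wiring is the point here.
    rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)
    persistent_all_reduce_atomic[(num_blocks,)](
        input_tensor, output_tensor, M, N,
        input_tensor.stride(0), input_tensor.stride(1),
        output_tensor.stride(0), output_tensor.stride(1),
        heap_bases,                    # assumed to come from the iris context, one base per iris rank
        group_rank=rank_in_group,      # ProcessGroup-relative: tile assignment, i == group_rank checks
        iris_rank=rank_global,         # iris-context rank: heap_bases indexing for iris.load/store/atomics
        world_size=world_size,
        rank_start=rank_start,         # maps group index i to a global rank: rank_start + i * rank_stride
        rank_stride=rank_stride,
    )

Note that in the two-shot kernel, DISTRIBUTION == 0 assigns tiles round-robin (rank r owns tiles r, r + world_size, ...), while the other mode gives each rank a contiguous block of tiles_per_rank tiles; both use group_rank, never iris_rank, for that assignment.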