
Commit ef53205 (parent a19c859)

Better naming, trying to do some cleanup.

6 files changed: +261 −84 lines


iris/ccl/all_gather.py (14 additions, 14 deletions)
@@ -24,8 +24,8 @@ def persistent_all_gather(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -51,8 +51,8 @@ def persistent_all_gather(
         stride_in_m, stride_in_n: Strides for input tensor
         stride_out_m, stride_out_n: Strides for output tensor
         heap_bases: Heap base pointers for all ranks
-        cur_rank: Current rank within the group (for comparisons)
-        cur_rank_global: Rank within the `iris` instance
+        group_rank: Rank within the ProcessGroup (0 to group_size-1), used for tile assignment and comparisons
+        iris_rank: Rank in the iris context, used for iris RMA operations (heap_bases indexing)
         world_size: Total number of ranks in the group
         BLOCK_SIZE_M, BLOCK_SIZE_N: Block sizes for tiling
         GROUP_SIZE_M: Group size for M dimension tiling
@@ -104,15 +104,15 @@ def persistent_all_gather(
         data = tl.load(input_ptr_source, mask=input_mask, other=0.0)
 
         # Send local shard data to all destination ranks
-        # Each rank's input goes to output[cur_rank * M : (cur_rank + 1) * M, :] on all ranks
+        # Each rank's input goes to output[group_rank * M : (group_rank + 1) * M, :] on all ranks
         for i in tl.static_range(world_size):
             target_rank = rank_start + i * rank_stride
 
-            # Compute global output row indices: offset by cur_rank * M
-            rm_output = rm_input + cur_rank * M
+            # Compute global output row indices: offset by group_rank * M
+            rm_output = rm_input + group_rank * M
 
             # Output mask: only write where input was valid
-            output_mask = (rm_output[:, None] < (cur_rank + 1) * M) & (rn[None, :] < N)
+            output_mask = (rm_output[:, None] < (group_rank + 1) * M) & (rn[None, :] < N)
 
             # Combine masks: must be valid in both input and output
             combined_mask = input_mask & output_mask
@@ -124,16 +124,16 @@ def persistent_all_gather(
             output_ptr_target = output_ptr + output_offset
             output_ptr_target = tl.multiple_of(output_ptr_target, (BLOCK_SIZE_M, BLOCK_SIZE_N))
 
-            if i == cur_rank:
-                # Local destination (i == rank_in_group): use direct store
+            if i == group_rank:
+                # Local destination (i == group_rank): use direct store
                 tl.store(output_ptr_target, data, mask=combined_mask, cache_modifier=".wt")
             else:
                 # Remote destination: use iris.store to send data to remote destination
-                # Use cur_rank_global for iris IPC operations
+                # Use iris_rank for iris RMA operations (heap_bases indexing)
                 iris.store(
                     output_ptr_target,
                     data,
-                    cur_rank_global,
+                    iris_rank,
                     target_rank,
                     heap_bases,
                     mask=combined_mask,
@@ -183,8 +183,8 @@ def all_gather(
     )
 
     # Extract group information
-    # rank_in_group: position within the group (0, 1, 2, ...) - used for comparisons
-    # rank_global: global rank across all processes - used for iris RMA operations
+    # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel
+    # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)
 
     M, N = input_tensor.shape[:2]
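A minimal host-side sketch (plain NumPy, not part of this commit) of the addressing scheme the renamed parameters encode: group_rank picks the destination row block inside every output buffer, while the iris rank derived from rank_start/rank_stride picks whose heap is written. The group layout below (rank_start=4, rank_stride=1) is an assumption chosen only for illustration.

import numpy as np

M, N = 4, 3
group_size = 2                    # ranks in the ProcessGroup
rank_start, rank_stride = 4, 1    # assumed: the group maps to iris ranks 4 and 5

# One M x N input shard per group rank; one (group_size*M) x N output per iris rank.
inputs = {g: np.full((M, N), g, dtype=np.float32) for g in range(group_size)}
outputs = {rank_start + g * rank_stride: np.zeros((group_size * M, N), np.float32) for g in range(group_size)}

for group_rank in range(group_size):                  # the rank doing the sends
    for i in range(group_size):                       # same loop shape as the kernel
        target_rank = rank_start + i * rank_stride    # iris rank that owns the buffer
        # group_rank selects the row block; target_rank selects whose heap is written.
        outputs[target_rank][group_rank * M:(group_rank + 1) * M, :] = inputs[group_rank]

# Every rank ends up with the same gathered tensor.
assert all(np.array_equal(outputs[rank_start], buf) for buf in outputs.values())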

iris/ccl/all_reduce.py (36 additions, 36 deletions)
@@ -138,8 +138,8 @@ def persistent_all_reduce_atomic(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -162,8 +162,8 @@ def persistent_all_reduce_atomic(
         M: Number of rows
         N: Number of columns
         heap_bases: Heap base pointers for all ranks
-        cur_rank: Current rank within the group (for comparisons)
-        cur_rank_global: Global rank (for iris IPC operations)
+        group_rank: Rank within the ProcessGroup (0 to group_size-1), used for tile assignment and comparisons
+        iris_rank: Rank in the iris context, used for iris RMA operations (heap_bases indexing)
         world_size: Total number of ranks in the group
     """
     pid = tl.program_id(0)
@@ -210,21 +210,21 @@ def persistent_all_reduce_atomic(
         data = tl.load(input_ptr_local, mask=mask)
 
         # Atomically add to output buffer on all ranks
-        # Each rank's output tensor is in its own heap, accessible via IPC
+        # Each rank's output tensor is in its own heap, accessible via RMA
        for i in range(world_size):
            target_rank = rank_start + i * rank_stride
-            if i == cur_rank:
-                # For the current rank (i == rank_in_group), use local atomic add
+            if i == group_rank:
+                # For the current rank (i == group_rank), use local atomic add
                # output_ptr is already in current rank's address space
                tl.atomic_add(output_ptr + output_offset, data, mask=mask)
            else:
                # For remote ranks, use iris.atomic_add to translate pointer
-                # This accesses the remote rank's heap via IPC
-                # Use cur_rank_global for iris operations (heap_bases indexing)
+                # This accesses the remote rank's heap via RMA
+                # Use iris_rank for iris operations (heap_bases indexing)
                iris.atomic_add(
                    output_ptr + output_offset,
                    data,
-                    cur_rank_global,
+                    iris_rank,
                    target_rank,
                    heap_bases,
                    mask=mask,
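Same idea for the atomic variant, as a host-side NumPy sketch (not part of this commit): each group_rank adds its local tile into every rank's output buffer, taking the local path when i == group_rank and the rank_start/rank_stride mapping for the remote ones. The group layout below is an illustrative assumption.

import numpy as np

M, N = 4, 3
group_size = 2
rank_start, rank_stride = 0, 2    # assumed: the group maps to iris ranks 0 and 2

inputs = {g: np.random.rand(M, N).astype(np.float32) for g in range(group_size)}
outputs = {rank_start + g * rank_stride: np.zeros((M, N), np.float32) for g in range(group_size)}

for group_rank in range(group_size):                  # the rank doing the adds
    for i in range(group_size):
        target_rank = rank_start + i * rank_stride    # iris rank that owns the buffer
        # i == group_rank corresponds to the local tl.atomic_add branch;
        # the other iterations correspond to iris.atomic_add into the remote heap.
        outputs[target_rank] += inputs[group_rank]

expected = sum(inputs.values())
assert all(np.allclose(buf, expected) for buf in outputs.values())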
@@ -245,8 +245,8 @@ def persistent_all_reduce_spinlock(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -310,7 +310,7 @@ def persistent_all_reduce_spinlock(
            remote_rank = rank_start + i * rank_stride
            partial = iris.load(
                input_ptr + input_offset,
-                cur_rank_global,
+                iris_rank,
                remote_rank,
                heap_bases,
                mask=mask,
@@ -333,8 +333,8 @@ def persistent_all_reduce_one_shot(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -389,7 +389,7 @@ def persistent_all_reduce_one_shot(
            remote_rank = rank_start + i * rank_stride
            partial = iris.load(
                input_ptr + input_offset,
-                cur_rank_global,
+                iris_rank,
                remote_rank,
                heap_bases,
                mask=mask,
@@ -416,8 +416,8 @@ def persistent_all_reduce_ring(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -504,7 +504,7 @@ def persistent_all_reduce_ring(
                remote_flag_ptr,
                0,
                0,
-                cur_rank_global,
+                iris_rank,
                next_rank,
                heap_bases,
                sem="acquire",
@@ -517,7 +517,7 @@ def persistent_all_reduce_ring(
            iris.store(
                ring_buffer + tile_offset,
                send_data,
-                cur_rank_global,
+                iris_rank,
                next_rank,
                heap_bases,
                mask=mask,
@@ -526,7 +526,7 @@ def persistent_all_reduce_ring(
            iris.atomic_xchg(
                remote_flag_ptr,
                1,
-                cur_rank_global,
+                iris_rank,
                next_rank,
                heap_bases,
                sem="release",
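The three ring hunks above keep the same ordering around the rename: wait on the neighbour's flag with acquire semantics, store the payload into its ring buffer, then flip the flag with a release. A minimal sketch of that style of handshake using plain Python threads (illustrative only, not the kernel's actual synchronization primitives):

import threading

class RingSlot:
    """Single producer/consumer slot; flag plays the role of remote_flag_ptr."""

    def __init__(self):
        self.flag = 0          # 0 = slot free, 1 = slot full
        self.data = None       # stands in for ring_buffer + tile_offset
        self.cond = threading.Condition()

    def send(self, payload):
        with self.cond:
            while self.flag != 0:          # wait until the slot is free ("acquire")
                self.cond.wait()
            self.data = payload            # like iris.store(ring_buffer + tile_offset, send_data, ...)
            self.flag = 1                  # like iris.atomic_xchg(remote_flag_ptr, 1, ..., sem="release")
            self.cond.notify_all()

    def recv(self):
        with self.cond:
            while self.flag != 1:          # consumer side, not shown in the hunks above
                self.cond.wait()
            payload, self.data, self.flag = self.data, None, 0
            self.cond.notify_all()
            return payload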
@@ -560,8 +560,8 @@ def persistent_all_reduce_two_shot(
     stride_out_m,
     stride_out_n,
     heap_bases: tl.tensor,
-    cur_rank: tl.constexpr,
-    cur_rank_global: tl.constexpr,
+    group_rank: tl.constexpr,
+    iris_rank: tl.constexpr,
     world_size: tl.constexpr,
     rank_start: tl.constexpr,
     rank_stride: tl.constexpr,
@@ -586,13 +586,13 @@ def persistent_all_reduce_two_shot(
 
    tiles_per_rank = tl.cdiv(total_tiles, world_size)
    if DISTRIBUTION == 0:
-        start_tile = cur_rank
+        start_tile = group_rank
        stride = world_size
        remaining = total_tiles - start_tile
        remaining = tl.maximum(remaining, 0)
        max_tile_offset = tl.cdiv(remaining, stride)
    else:
-        start_tile = cur_rank * tiles_per_rank
+        start_tile = group_rank * tiles_per_rank
        stride = 1
        remaining = total_tiles - start_tile
        remaining = tl.maximum(remaining, 0)
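The DISTRIBUTION switch above assigns tiles to group_rank in one of two ways: round-robin (start at group_rank, step by world_size) or contiguous blocks of tiles_per_rank tiles. A small standalone sketch of the resulting ownership (the per-rank clamp in the contiguous branch is assumed, since the hunk is truncated at that point):

import math

def tiles_for_rank(group_rank, world_size, total_tiles, distribution):
    """Tile indices a given group_rank owns under the two distribution modes."""
    tiles_per_rank = math.ceil(total_tiles / world_size)
    if distribution == 0:
        # Round-robin: rank r takes tiles r, r + world_size, r + 2 * world_size, ...
        start_tile, stride, end = group_rank, world_size, total_tiles
    else:
        # Contiguous: rank r takes up to tiles_per_rank consecutive tiles.
        start_tile, stride = group_rank * tiles_per_rank, 1
        end = min(total_tiles, start_tile + tiles_per_rank)
    return list(range(start_tile, end, stride))

# Example: 10 tiles across 4 ranks.
for r in range(4):
    print(r, tiles_for_rank(r, 4, 10, 0), tiles_for_rank(r, 4, 10, 1))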
@@ -636,11 +636,11 @@ def persistent_all_reduce_two_shot(
 
        start_rank_idx = pid % world_size
        start_rank_global = rank_start + start_rank_idx * rank_stride
-        acc = iris.load(base_ptr, cur_rank_global, start_rank_global, heap_bases).to(acc_dtype)
+        acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases).to(acc_dtype)
        for i in tl.static_range(1, world_size):
            remote_rank_idx = (start_rank_idx + i) % world_size
            remote_rank = rank_start + remote_rank_idx * rank_stride
-            acc += iris.load(base_ptr, cur_rank_global, remote_rank, heap_bases).to(acc_dtype)
+            acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases).to(acc_dtype)
 
        reduced = acc.to(output_ptr.type.element_ty)
 
@@ -649,8 +649,8 @@ def persistent_all_reduce_two_shot(
        for i in tl.static_range(0, world_size):
            remote_rank_idx = (start_rank_idx + i) % world_size
            remote_rank = rank_start + remote_rank_idx * rank_stride
-            if remote_rank_idx != cur_rank:
-                iris.store(out_ptr, reduced, cur_rank_global, remote_rank, heap_bases)
+            if remote_rank_idx != group_rank:
+                iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases)
 
        # Slow path: MASKED (only boundary tiles land here)
        # This path handles tiles at tensor boundaries where not all elements are valid.
@@ -659,11 +659,11 @@ def persistent_all_reduce_two_shot(
 
        start_rank_idx = pid % world_size
        start_rank_global = rank_start + start_rank_idx * rank_stride
-        acc = iris.load(base_ptr, cur_rank_global, start_rank_global, heap_bases, mask=mask).to(acc_dtype)
+        acc = iris.load(base_ptr, iris_rank, start_rank_global, heap_bases, mask=mask).to(acc_dtype)
        for i in tl.static_range(1, world_size):
            remote_rank_idx = (start_rank_idx + i) % world_size
            remote_rank = rank_start + remote_rank_idx * rank_stride
-            acc += iris.load(base_ptr, cur_rank_global, remote_rank, heap_bases, mask=mask).to(acc_dtype)
+            acc += iris.load(base_ptr, iris_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype)
 
        reduced = acc.to(output_ptr.type.element_ty)
 
@@ -672,8 +672,8 @@ def persistent_all_reduce_two_shot(
        for i in tl.static_range(0, world_size):
            remote_rank_idx = (start_rank_idx + i) % world_size
            remote_rank = rank_start + remote_rank_idx * rank_stride
-            if remote_rank_idx != cur_rank:
-                iris.store(out_ptr, reduced, cur_rank_global, remote_rank, heap_bases, mask=mask)
+            if remote_rank_idx != group_rank:
+                iris.store(out_ptr, reduced, iris_rank, remote_rank, heap_bases, mask=mask)
 
 
 def all_reduce(
@@ -729,8 +729,8 @@ def all_reduce(
     )
 
     # Extract group information
-    # rank_in_group: position within the group (0, 1, 2, ...) - used for tile assignment and comparisons
-    # rank_global: global rank across all processes - used for iris IPC operations
+    # rank_in_group: position within the ProcessGroup (0, 1, 2, ...) - passed as group_rank to kernel
+    # rank_global: global rank in iris context - passed as iris_rank to kernel for RMA operations
     rank_in_group, rank_global, world_size, rank_start, rank_stride = extract_group_info(group, shmem)
     M, N = input_tensor.shape[:2]
 
@@ -843,7 +843,7 @@ def all_reduce(
     )
 
     # Calculate next rank in the ring for group support
-    # next_rank must be a global rank for iris IPC operations
+    # next_rank must be a global rank for iris RMA operations
     if group is None:
         # Simple case: next rank is just (rank_in_group + 1) % world_size (which equals global rank)
         next_rank = (rank_in_group + 1) % world_size
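The diff only shows the `group is None` branch, where the group rank and the global rank coincide. A hedged sketch of how the ring neighbour could be mapped to a global/iris rank when a group is present, reusing the rank_start + i * rank_stride convention from the kernels above (the group branch here is an assumption, not shown in this commit):

def ring_next_rank(rank_in_group, world_size, rank_start=0, rank_stride=1, group=None):
    """Global rank of the next ring neighbour; the group branch is illustrative only."""
    next_in_group = (rank_in_group + 1) % world_size
    if group is None:
        # Branch shown in the diff: group rank equals global rank.
        return next_in_group
    # Assumed mapping for the group case, following rank_start + i * rank_stride.
    return rank_start + next_in_group * rank_stride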
