Skip to content
This repository was archived by the owner on Jan 28, 2025. It is now read-only.

Commit 4bef415

Browse files
committed
WIP
1 parent 25744c3 commit 4bef415

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from msccl.language.collectives import AllReduce
2+
3+
def allreduce_ring():
    """Work-in-progress sketch of a ring communication schedule.

    For every rank ``r`` the function looks up a channel towards the next
    rank in the ring and one towards the previous rank, then performs
    ``size - 1`` steps.  Each step puts one chunk to the next rank, issues a
    rank-level barrier, signals on the send channel and waits on the receive
    channel.

    Only put / barrier / signal / wait operations appear here; no reduction
    step is present in this snippet.

    NOTE(review): ``size``, ``src_rank``, ``src_buffer``, ``dst_buffer``,
    ``channel_type``, ``tag``, ``buffer_type``, ``tb_list``, ``sendtb``,
    ``recvtb``, ``get_channel``, ``chunk`` and ``get_rank`` are not defined
    in this snippet; they are presumably supplied by the surrounding module
    or are still placeholders -- confirm before running.

    NOTE(review): the nesting below was reconstructed from a listing that
    had lost its indentation; verify that barrier/signal/wait really belong
    inside the inner loop.
    """
    for r in range(size):
        # Ring neighbours of rank r: next rank to send to, previous to
        # receive from (Python's % keeps the result non-negative for r == 0).
        send_to_peer = (r + 1) % size
        recv_from_peer = (r - 1) % size

        # for channel the key is
        # (src_rank, src_buffer, dst_rank, dst_buffer, channel_type)
        # NOTE(review): both lookups pass `src_rank`, not `r`, as the source,
        # and the receive channel is keyed with the previous rank in the
        # destination slot -- confirm this is the intended key.
        send_ch = get_channel(
            src_rank, src_buffer, send_to_peer, dst_buffer, channel_type, tag
        )
        recv_ch = get_channel(
            src_rank, src_buffer, recv_from_peer, dst_buffer, channel_type, tag
        )

        for i in range(size - 1):
            # The chunk index rotates around the ring with the step number.
            send_index = (r + i) % size
            c = chunk(r, buffer_type, send_index, size)
            # The original passed an unbound name `ch`; the channel that
            # targets `send_to_peer` in this scope is `send_ch`.
            c.put(
                send_to_peer,
                dst_buffer,
                send_index,
                sendtbs=[tb_list],
                ch=send_ch,
            )
            # need barrier to make sure the put is done
            rank = get_rank(r)
            rank.barrier(tbs=[tb_list])
            send_ch.signal(tb=sendtb)
            recv_ch.wait(tb=recvtb)

0 commit comments

Comments
 (0)