Skip to content
This repository was archived by the owner on Jan 28, 2025. It is now read-only.

Commit f3371da

Browse files
mahdiehghazim, Ubuntu, Ubuntu, SreevatsaAnantharamu, Binyang2014
authored
add broadcast support to msccl-tools (#33)
add broadcast support to msccl-tools --------- Co-authored-by: Ubuntu <hpcuser@mahdieh-V53QMP.lku2jwkwlucefebrgtw3pqexke.jx.internal.cloudapp.net> Co-authored-by: Ubuntu <hpcuser@mahdieh-WYRPGZ.lku2jwkwlucefebrgtw3pqexke.jx.internal.cloudapp.net> Co-authored-by: Sreevatsa Anantharamu <sreevatsanadig@gmail.com> Co-authored-by: Binyang Li <binyli@microsoft.com>
1 parent ceaf52f commit f3371da

File tree

1 file changed

+52
-2
lines changed

1 file changed

+52
-2
lines changed

msccl/language/collectives.py

Lines changed: 52 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,9 @@
22
from msccl.language.ir import Buffer
33
from msccl.language import *
44

5-
65
class Collective:
76

8-
def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
7+
def __init__(self, num_ranks, chunk_factor, inplace, root=0, num_ranks_per_node=-1, **kwargs):
98
self.num_ranks = num_ranks
109
self.chunk_factor = chunk_factor
1110
self.inplace = inplace
@@ -71,6 +70,57 @@ def check(self, prog):
7170
correct = False
7271
return correct
7372

73+
class Broadcast(Collective):
    """Broadcast collective: every rank's output must hold the root's chunks.

    Args:
        num_ranks: Number of participating ranks.
        chunk_factor: Number of chunks in each rank's buffer.
        inplace: When true, input and output refer to the same buffer.
        root: Rank whose chunks are broadcast to all ranks.
        create_all_chunks: Accepted but not used anywhere in this class
            (presumably kept for signature parity with ``AllGather`` --
            TODO confirm).
    """

    def __init__(self, num_ranks, chunk_factor, inplace, root, create_all_chunks=False):
        Collective.__init__(self, num_ranks, chunk_factor, inplace, root)
        self.name = "broadcast"
        self.root = root

    def init_buffers(self):
        """Build the initial per-rank buffers for a broadcast.

        Returns:
            A list with one dict per rank, mapping ``Buffer.input`` and
            ``Buffer.output`` to lists of ``chunk_factor`` entries.
        """
        rank_buffers = []
        for r in range(self.num_ranks):
            if self.inplace:
                # In-place broadcast only uses the input buffer, so the output
                # entry aliases the very same list object.
                # NOTE(review): every rank (not only the root) starts out
                # holding the root's chunks here, which makes check() pass
                # trivially for in-place programs -- confirm this is intended.
                input_buffer = [
                    Chunk(self.root, ch, -1, ch) for ch in range(self.chunk_factor)
                ]
                buffers = {
                    Buffer.input: input_buffer,
                    Buffer.output: input_buffer,
                }
            else:
                # Out-of-place: only the root starts with data; all other
                # input buffers and every output buffer start empty.
                if r == self.root:
                    input_buffer = [
                        Chunk(self.root, ch, -1, ch)
                        for ch in range(self.chunk_factor)
                    ]
                else:
                    input_buffer = [None] * self.chunk_factor
                output_buffer = [None] * self.chunk_factor
                buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
            rank_buffers.append(buffers)
        return rank_buffers

    def check(self, prog):
        """Verify that each rank's output buffer holds the root's chunks.

        Prints one diagnostic line per incorrect chunk.

        Args:
            prog: Object exposing ``buffers``, indexed by rank and then by
                ``Buffer`` kind.

        Returns:
            True if chunk ``ch`` of every rank's output buffer originates
            from ``(self.root, ch)``; False otherwise.
        """
        correct = True
        for r in range(self.num_ranks):
            # Inspect the buffer of the rank being checked; indexing rank 0
            # here would leave every other rank unvalidated.
            output = prog.buffers[r][Buffer.output]
            for ch in range(self.chunk_factor):
                chunk = output[ch]
                if chunk is None:
                    print(f"Rank {r} chunk {ch} is incorrect should be ({self.root}, {ch}) given None")
                    correct = False
                elif chunk.origin_rank != self.root or chunk.origin_index != ch:
                    print(f"Rank {r} chunk {ch} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})")
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        """Return ``(buffer, index)`` unchanged; ``rank`` is not used."""
        return buffer, index
74124

75125
class AllGather(Collective):
76126
def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False):

0 commit comments

Comments
 (0)