Skip to content
This repository was archived by the owner on Jan 28, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from msccl.language.ir import Buffer
from msccl.language import *


class Collective:

def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs):
def __init__(self, num_ranks, chunk_factor, inplace, root=0, num_ranks_per_node=-1, **kwargs):
self.num_ranks = num_ranks
self.chunk_factor = chunk_factor
self.inplace = inplace
Expand Down Expand Up @@ -71,6 +70,57 @@ def check(self, prog):
correct = False
return correct

class Broadcast(Collective):
    """Broadcast collective: every rank should end up with the root's chunks.

    Each rank's buffers hold ``chunk_factor`` chunks. A chunk is identified by
    its origin rank and origin index; after a correct broadcast, output slot
    ``ch`` on every rank holds the chunk that originated at ``(root, ch)``.
    """

    def __init__(self, num_ranks, chunk_factor, inplace, root, create_all_chunks=False):
        """Set up a broadcast over ``num_ranks`` ranks rooted at ``root``.

        Args:
            num_ranks: Number of participating ranks.
            chunk_factor: Number of chunks in each rank's buffer.
            inplace: When true, input and output are the same buffer.
            root: Rank whose chunks are broadcast.
            create_all_chunks: Accepted for signature parity with the other
                collectives; it is not used by this class.
        """
        Collective.__init__(self, num_ranks, chunk_factor, inplace, root)
        self.name = "broadcast"
        self.root = root

    # Initializes the input/output buffers for a broadcast
    def init_buffers(self):
        """Return one ``{Buffer.input: ..., Buffer.output: ...}`` dict per rank."""
        rank_buffers = []
        if self.inplace:
            # In-place broadcast only uses the input buffer, so both keys
            # deliberately refer to the same list object.
            # NOTE(review): every rank (not only the root) is seeded with the
            # root's chunks here, as in the original code - confirm intended.
            for _ in range(self.num_ranks):
                input_buffer = [
                    Chunk(self.root, ch, -1, ch) for ch in range(self.chunk_factor)
                ]
                rank_buffers.append({
                    Buffer.input: input_buffer,
                    Buffer.output: input_buffer,
                })
        else:
            for r in range(self.num_ranks):
                if r == self.root:
                    # Only the root starts with data in its input buffer.
                    input_buffer = [
                        Chunk(self.root, ch, -1, ch) for ch in range(self.chunk_factor)
                    ]
                else:
                    input_buffer = [None] * self.chunk_factor
                output_buffer = [None] * self.chunk_factor
                rank_buffers.append({
                    Buffer.input: input_buffer,
                    Buffer.output: output_buffer,
                })
        return rank_buffers

    # Expected output buffer for broadcast
    def check(self, prog):
        """Verify that every rank's output buffer holds the root's chunks.

        Prints one diagnostic line per wrong or missing chunk and keeps
        scanning, so a single call reports every mismatch.

        Returns:
            True if all ranks hold chunk ``(root, ch)`` at output index ``ch``
            for every ``ch``; False otherwise.
        """
        correct = True
        buf = Buffer.output
        for r in range(self.num_ranks):
            # Inspect rank r's own buffer; reading rank 0's buffer on every
            # iteration would leave all other ranks unverified.
            output = prog.buffers[r][buf]
            for ch in range(self.chunk_factor):
                index = ch
                chunk = output[index]
                if chunk is None:
                    # The expected origin rank is the broadcast root.
                    print(f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given None")
                    correct = False
                elif chunk.origin_rank != self.root or chunk.origin_index != ch:
                    print(f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})")
                    correct = False
        return correct

    def get_buffer_index(self, rank, buffer, index):
        """Return ``(buffer, index)`` unchanged; ``rank`` is not used."""
        return buffer, index

class AllGather(Collective):
def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False):
Expand Down
Loading