Skip to content
This repository was archived by the owner on Jan 28, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions examples/mscclang/mscclpp/simple/allgather_ring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2024 Advanced Micro Devices.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllGather

# Implementation of ring-based AllGather where data is being pushed using put()
def allgather_ring_push(size, in_place):
    """Describe a ring AllGather in which every rank pushes data with put().

    The program runs ``size - 1`` steps. In each step every rank puts one
    chunk of its output buffer into the same slot of the next rank's output
    buffer, signals that rank, and then waits on the previous rank for the
    slot it is receiving in this step.

    Args:
        size: Number of ranks in the ring.
        in_place: Forwarded to ``AllGather``. When false, each rank first
            copies its input chunk into slot ``rank`` of its output buffer.
    """
    ring = fully_connected(size)
    allgather = AllGather(size, 1, in_place)
    with MSCCLPPProgram("allgather_ring_push", ring, allgather, 1):
        if not in_place:
            # Out-of-place: seed each rank's output buffer with its own data.
            for r in range(size):
                chunk(r, Buffer.input, 0).copy(r, Buffer.output, r)

        for step in range(size - 1):
            for r in range(size):
                successor = (r + 1) % size
                predecessor = (r - 1) % size
                send_slot = (r - step) % size
                incoming_slot = (r - step - 1) % size

                # Put & signal: forward the slot obtained in the previous step
                # (or the rank's own slot when step == 0).
                # TODO how does this guarantee that the buffer is ready?
                outgoing = chunk(r, Buffer.output, send_slot)
                outgoing.put(successor, Buffer.output, send_slot, sendtb=0)
                outgoing.signal(successor, Buffer.output, send_slot, 0)

                # Wait for the predecessor's signal for the incoming slot.
                incoming = chunk(r, Buffer.output, incoming_slot)
                incoming.wait(predecessor, Buffer.output, incoming_slot, 0)
        Json()
        Check()

# Implementation of ring-based AllGather where data is being pulled using get()
def allgather_ring_pull(size, in_place):
    """Describe a ring AllGather in which every rank pulls data with get().

    The program runs ``size - 1`` steps. In each step every rank signals the
    next rank about the slot it currently holds, waits on the previous rank
    for the slot it is about to receive, and then gets that slot from the
    previous rank's output buffer.

    Args:
        size: Number of ranks in the ring.
        in_place: Forwarded to ``AllGather``. When false, each rank first
            copies its input chunk into slot ``rank`` of its output buffer.
    """
    ring = fully_connected(size)
    allgather = AllGather(size, 1, in_place)
    with MSCCLPPProgram("allgather_ring_pull", ring, allgather, 1):
        if not in_place:
            # Out-of-place: seed each rank's output buffer with its own data.
            for r in range(size):
                chunk(r, Buffer.input, 0).copy(r, Buffer.output, r)

        for step in range(size - 1):
            for r in range(size):
                successor = (r + 1) % size
                predecessor = (r - 1) % size
                ready_slot = (r - step) % size
                incoming_slot = (r - step - 1) % size

                # Signal the successor for the slot this rank holds.
                held = chunk(r, Buffer.output, ready_slot)
                held.signal(successor, Buffer.output, ready_slot, 0)

                # Wait & get: fetch the incoming slot after the predecessor's
                # signal.
                incoming = chunk(r, Buffer.output, incoming_slot)
                incoming.wait(predecessor, Buffer.output, incoming_slot, 0)
                incoming.get(predecessor, Buffer.output, incoming_slot, recvtb=0)
        Json()
        Check()

def _str_to_bool(text):
    """Convert a command-line token such as "true", "False" or "0" to a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including "False") becomes True. This converter parses
    the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``text`` is not a recognised boolean.
    """
    token = text.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays false: under the previous ``type=bool`` it was
    # the only value that turned in-place mode off.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help='number of gpus')
parser.add_argument('--in-place', type=_str_to_bool, default=True, help='Do collective in-place?')

args = parser.parse_args()

# Only one variant runs per invocation; swap the calls to use the push variant.
# allgather_ring_push(args.num_gpus, args.in_place)
allgather_ring_pull(args.num_gpus, args.in_place)
4 changes: 2 additions & 2 deletions msccl/language/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def init_buffers(self):
if self.inplace:
# Inplace AllGather only uses the output buffer
for r in range(self.num_ranks):
output_buffer = [None] * (self.num_ranks * self.chunk_factor)
output_buffer = [Chunk(-1, -1, -1, -1)] * (self.num_ranks * self.chunk_factor)
for ch in range(self.chunk_factor):
output_buffer[r * self.chunk_factor + ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch)
buffers = {
Expand All @@ -83,7 +83,7 @@ def init_buffers(self):
else:
for r in range(self.num_ranks):
input_buffer = [None] * self.chunk_factor
output_buffer = [None] * (self.num_ranks * self.chunk_factor)
output_buffer = [Chunk(-1, -1, -1, -1)] * (self.num_ranks * self.chunk_factor)
for ch in range(self.chunk_factor):
input_buffer[ch] = Chunk(r, ch, -1, r * self.chunk_factor + ch)
buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer}
Expand Down
2 changes: 1 addition & 1 deletion msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def remove_empty_fields(d):
gpus.append(gpu_instance)
obj = {
"name": program.name,
"colletive": program.collective,
"collective": program.collective,
"protocol": program.protocol,
"inplace": program.inplace,
"gpus": gpus,
Expand Down