|
2 | 2 | from msccl.language.ir import Buffer |
3 | 3 | from msccl.language import * |
4 | 4 |
|
5 | | - |
6 | 5 | class Collective: |
7 | 6 |
|
8 | | - def __init__(self, num_ranks, chunk_factor, inplace, num_ranks_per_node=-1, **kwargs): |
| 7 | + def __init__(self, num_ranks, chunk_factor, inplace, root=0, num_ranks_per_node=-1, **kwargs): |
9 | 8 | self.num_ranks = num_ranks |
10 | 9 | self.chunk_factor = chunk_factor |
11 | 10 | self.inplace = inplace |
@@ -71,6 +70,57 @@ def check(self, prog): |
71 | 70 | correct = False |
72 | 71 | return correct |
73 | 72 |
|
| 73 | +class Broadcast(Collective): |
| 74 | + def __init__(self, num_ranks, chunk_factor, inplace, root, create_all_chunks=False): |
| 75 | + Collective.__init__(self, num_ranks, chunk_factor, inplace, root) |
| 76 | + self.name = "broadcast" |
| 77 | + self.root=root |
| 78 | + |
| 79 | + # Initializes input buffer for an broadcast |
| 80 | + def init_buffers(self): |
| 81 | + rank_buffers = [] |
| 82 | + if self.inplace: |
| 83 | + # Inplace broadcast only uses the input buffer |
| 84 | + for r in range(self.num_ranks): |
| 85 | + input_buffer = [None] * (self.chunk_factor) |
| 86 | + for ch in range(self.chunk_factor): |
| 87 | + input_buffer[ch] = Chunk(self.root, ch, -1, ch) |
| 88 | + buffers = { |
| 89 | + Buffer.input: input_buffer, |
| 90 | + Buffer.output: input_buffer, |
| 91 | + } |
| 92 | + rank_buffers.append(buffers) |
| 93 | + else: |
| 94 | + for r in range(self.num_ranks): |
| 95 | + input_buffer = [None] * self.chunk_factor |
| 96 | + output_buffer = [None] * self.chunk_factor |
| 97 | + if r==self.root: |
| 98 | + for ch in range(self.chunk_factor): |
| 99 | + input_buffer[ch] = Chunk(self.root, ch, -1, ch) |
| 100 | + buffers = {Buffer.input: input_buffer, Buffer.output: output_buffer} |
| 101 | + rank_buffers.append(buffers) |
| 102 | + return rank_buffers |
| 103 | + |
| 104 | + # Expected output buffer for broadcast |
| 105 | + def check(self, prog): |
| 106 | + correct = True |
| 107 | + buf = Buffer.output |
| 108 | + for r in range(self.num_ranks): |
| 109 | + output = prog.buffers[0][buf] |
| 110 | + for ch in range(self.chunk_factor): |
| 111 | + index = ch |
| 112 | + chunk = output[index] |
| 113 | + if chunk is None: |
| 114 | + print(f"Rank {r} chunk {index} is incorrect should be ({i}, {ch}) given None") |
| 115 | + correct = False |
| 116 | + elif chunk.origin_rank != self.root or chunk.origin_index != ch: |
| 117 | + print(f"Rank {r} chunk {index} is incorrect should be ({self.root}, {ch}) given ({chunk.origin_rank}, {chunk.origin_index})") |
| 118 | + correct = False |
| 119 | + return correct |
| 120 | + |
| 121 | + |
| 122 | + def get_buffer_index(self, rank, buffer, index): |
| 123 | + return buffer, index |
74 | 124 |
|
75 | 125 | class AllGather(Collective): |
76 | 126 | def __init__(self, num_ranks, chunk_factor, inplace, create_all_chunks=False): |
|
0 commit comments