
Commit b76f3eb

Add 2 Node AllReduce DSL Algorithm (#636)
This PR adds two AllReduce algorithms designed for a 2-node environment. The algorithms are in-place and non-zero-copy.
1 parent 79e9e61 commit b76f3eb
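
For intuition, here is a minimal plain-Python sketch of the data flow the DSL program below encodes (hypothetical, not using the mscclpp DSL): each rank first reduces its slice across its own node, then exchanges the partial result with the rank holding the same local GPU id on the other node, and finally broadcasts within the node. Rank counts and values are illustrative only.

```python
# Hypothetical model of the three phases with illustrative sizes and values.
num_nodes, gpus_per_node = 2, 2
total_gpus = num_nodes * gpus_per_node
# data[rank][g]: the value that `rank` contributes to slice g.
data = [[100 * rank + g for g in range(gpus_per_node)] for rank in range(total_gpus)]

# Phase 1: intra-node reduce -- local GPU g on node n accumulates slice g
# across the ranks of its own node.
partial = [
    [sum(data[n * gpus_per_node + r][g] for r in range(gpus_per_node))
     for g in range(gpus_per_node)]
    for n in range(num_nodes)
]

# Phase 2: inter-node exchange -- the owner of slice g adds the partial sum
# computed for the same slice on the other node.
full = [partial[0][g] + partial[1][g] for g in range(gpus_per_node)]

# Phase 3: intra-node broadcast -- every rank ends up with every fully
# reduced slice, i.e. the element-wise sum over all ranks.
for rank in range(total_gpus):
    assert full == [sum(data[r][g] for r in range(total_gpus))
                    for g in range(gpus_per_node)]
```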

File tree: 1 file changed (+199, −0)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""
Multi-node AllReduce implementation using packet-based communication.
This implements a hierarchical AllReduce: intra-node allreduce followed by
inter-node exchange and final intra-node allreduce.
"""

import argparse
from mscclpp.language.channel import *
from mscclpp.language.rank import *
from mscclpp.language.general import *
from mscclpp.language.program import *
from mscclpp.language.collectives import *


def allreduce_example(
    program_name, gpus_per_node, thread_block_group_size, num_threads_per_block, min_message_size, max_message_size
):
    """
    Implements a multi-node AllReduce using a hierarchical approach:
    1. Intra-node allreduce
    2. Inter-node exchange (exchange reduced data between nodes)
    3. Intra-node allreduce
    """
    # Configuration constants
    num_nodes = 2
    total_gpus = num_nodes * gpus_per_node
    chunks_per_loop = 1
    packets_per_gpu = 2  # Each GPU handles 2 data packets

    # Initialize collective operation
    collective = AllReduce(total_gpus, chunks_per_loop, True)

    with CollectiveProgram(
        program_name,
        collective,
        total_gpus,
        protocol="LL",
        num_threads_per_block=num_threads_per_block,
        reuse_resources=False,
        use_double_scratch_buffer=True,
        min_message_size=min_message_size,
        max_message_size=max_message_size,
    ):
        # Initialize communication channels and buffers
        intra_node_memory_channels = {}
        inter_node_port_channels = {}
        scratch_buffers = []
        thread_block_offset = 1
        thread_block_group = ThreadBlockGroup(
            tb_list=[i for i in range(thread_block_offset, thread_block_offset + thread_block_group_size)]
        )

        for node_id in range(num_nodes):
            for local_gpu_id in range(gpus_per_node):
                current_rank_id = local_gpu_id + gpus_per_node * node_id
                next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus
                scratch_buffer_size = 2 * total_gpus
                scratch_buffers.append(Buffer(current_rank_id, scratch_buffer_size))
                for peer_gpu_id in range(gpus_per_node):
                    if peer_gpu_id != local_gpu_id:
                        peer_rank_id = peer_gpu_id + gpus_per_node * node_id
                        intra_node_memory_channels[(peer_rank_id, current_rank_id)] = MemoryChannel(
                            peer_rank_id, current_rank_id
                        )
                inter_node_port_channels[current_rank_id] = PortChannel(next_node_rank_id, current_rank_id)

        # AllReduce
        for node_id in range(num_nodes):
            for local_gpu_id in range(gpus_per_node):
                current_rank_id = local_gpu_id + gpus_per_node * node_id
                current_rank = Rank(current_rank_id)
                input_buffer = current_rank.get_input_buffer()
                next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus

                # Intra Node Exchange Data
                for peer_gpu_id in range(gpus_per_node):
                    peer_rank_id = peer_gpu_id + gpus_per_node * node_id
                    peer_data_offset = peer_gpu_id * packets_per_gpu
                    if peer_gpu_id != local_gpu_id:
                        intra_node_memory_channels[(peer_rank_id, current_rank_id)].put_packets(
                            scratch_buffers[peer_rank_id][
                                local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
                            ],
                            input_buffer[peer_data_offset : peer_data_offset + packets_per_gpu],
                            tb_group=thread_block_group,
                        )

                # Intra Node Reduce
                other_gpu_data = [
                    scratch_buffers[current_rank_id][
                        packets_per_gpu * gpu_idx : packets_per_gpu * gpu_idx + packets_per_gpu
                    ]
                    for gpu_idx in range(gpus_per_node)
                    if gpu_idx != local_gpu_id
                ]
                current_rank.reduce(
                    input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
                    other_gpu_data,
                    tb_group=thread_block_group,
                    packet=True,
                )

                # Copy Reduced Data to Scratch Buffer and send to Next Node
                current_rank.copy_packets(
                    scratch_buffers[current_rank_id][
                        local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
                    ],
                    input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
                    tb_group=thread_block_group,
                )
                inter_node_offset = total_gpus
                inter_node_port_channels[current_rank_id].read_put_packets(
                    scratch_buffers[next_node_rank_id][
                        inter_node_offset
                        + local_gpu_id * packets_per_gpu : inter_node_offset
                        + local_gpu_id * packets_per_gpu
                        + packets_per_gpu
                    ],
                    scratch_buffers[current_rank_id][
                        local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
                    ],
                    tb=0,
                )

                # Reduce Received Data from Remote Node
                inter_node_data = [
                    scratch_buffers[current_rank_id][
                        inter_node_offset
                        + local_gpu_id * packets_per_gpu : inter_node_offset
                        + local_gpu_id * packets_per_gpu
                        + packets_per_gpu
                    ]
                ]
                current_rank.reduce(
                    input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
                    inter_node_data,
                    tb_group=thread_block_group,
                    packet=True,
                )

                # Broadcast Reduced Data
                for peer_gpu_id in range(gpus_per_node):
                    peer_rank_id = peer_gpu_id + gpus_per_node * node_id

                    if peer_gpu_id != local_gpu_id:
                        intra_node_memory_channels[(peer_rank_id, current_rank_id)].put_packets(
                            scratch_buffers[peer_rank_id][
                                inter_node_offset
                                + local_gpu_id * packets_per_gpu : inter_node_offset
                                + local_gpu_id * packets_per_gpu
                                + packets_per_gpu
                            ],
                            input_buffer[
                                local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
                            ],
                            tb_group=thread_block_group,
                        )

                # Unpack Data Received from other GPUs in the same node
                for peer_gpu_id in range(gpus_per_node):
                    if peer_gpu_id != local_gpu_id:
                        current_rank.unpack_packets(
                            input_buffer[
                                peer_gpu_id * packets_per_gpu : peer_gpu_id * packets_per_gpu + packets_per_gpu
                            ],
                            scratch_buffers[current_rank_id][
                                inter_node_offset
                                + peer_gpu_id * packets_per_gpu : inter_node_offset
                                + peer_gpu_id * packets_per_gpu
                                + packets_per_gpu
                            ],
                            tb_group=thread_block_group,
                        )

        print(JSON())


parser = argparse.ArgumentParser()

parser.add_argument("--name", type=str, help="name of the program")
parser.add_argument("--gpus_per_node", type=int, help="number of gpus per node")
parser.add_argument("--tbg_size", type=int, help="number of thread blocks in the thread block group")
parser.add_argument("--num_threads_per_block", type=int, default=1024, help="number of threads per block")
parser.add_argument("--min_message_size", type=int, default=0, help="minimum message size")
parser.add_argument("--max_message_size", type=int, default=2 * 2**20, help="maximum message size")

args = parser.parse_args()

allreduce_example(
    args.name,
    args.gpus_per_node,
    args.tbg_size,
    args.num_threads_per_block,
    args.min_message_size,
    args.max_message_size,
)
```
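
For reference, a hypothetical direct call to the generator, bypassing argparse; the argument values are illustrative and mirror the CLI defaults above. The call emits the generated execution plan as JSON on stdout via `print(JSON())`:

```python
# Hypothetical example invocation; values are illustrative only.
allreduce_example(
    program_name="allreduce_2node",
    gpus_per_node=8,
    thread_block_group_size=4,
    num_threads_per_block=1024,  # CLI default
    min_message_size=0,          # CLI default
    max_message_size=2 * 2**20,  # CLI default (2 MiB)
)
```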
