Skip to content

Commit 8a56497

Browse files
caiomcbrCaio Rocha
andauthored
Updating MSCCLLang Examples (#462)
Co-authored-by: Caio Rocha <aiorocha@microsoft.com>
1 parent 55789bc commit 8a56497

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

python/examples/allreduce_allpairs_packet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def allreduce_allpairs(gpus, instances):
3535
remote_rank = tb
3636
index = remote_rank * size
3737
c = chunk(r1, Buffer.input, index, size)
38-
c.put_packet(remote_rank, "scratch", index=r1 * size, sendtb=tb)
38+
c.put_packet(remote_rank, Buffer.scratch, index=r1 * size, sendtb=tb)
3939

4040
# Each rank performs a local reduction on the nth chunk
4141
# Utilize 8 threadblocks for this reduction for better parallelism
@@ -44,16 +44,16 @@ def allreduce_allpairs(gpus, instances):
4444
c = chunk(r, Buffer.input, r * size + index)
4545
for peer in range(size):
4646
if peer != r:
47-
c.reduce_packet(chunk(r, "scratch", peer * size + index), recvtb=index)
47+
c.reduce_packet(chunk(r, Buffer.scratch, peer * size + index), recvtb=index)
4848
for peer in range(size):
4949
if peer != r:
50-
c.put_packet(peer, "scratch", (size * size) + r * size + index, sendtb=index)
50+
c.put_packet(peer, Buffer.scratch, (size * size) + r * size + index, sendtb=index)
5151

5252
# Each rank get final result from scratch space
5353
for r in range(size):
5454
for peer in range(size):
5555
if peer != r:
56-
c = chunk(r, "scratch", size * size + peer * size, size)
56+
c = chunk(r, Buffer.scratch, size * size + peer * size, size)
5757
c.copy_packet(r, Buffer.input, peer * size, sendtb=peer)
5858

5959
Json()

python/examples/send_recv_packet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,16 @@ def send_recv(instances):
3333
c = chunk(r, Buffer.input, 0)
3434
c.put_packet(
3535
nghr,
36-
"scratch",
36+
Buffer.scratch,
3737
1,
3838
sendtb=0,
3939
chan_type=ChannelType.port,
40-
temp_buffer="scratch",
40+
temp_buffer=Buffer.scratch,
4141
temp_buffer_index=0,
4242
)
4343

4444
for r in range(size):
45-
c = chunk(r, "scratch", 1)
45+
c = chunk(r, Buffer.scratch, 1)
4646
c.copy_packet(r, Buffer.output, 0, sendtb=0)
4747

4848
Json()

python/examples/send_recv_proxy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def send_recv(instances):
3131
c = chunk(r, Buffer.input, 0)
3232
c.put(
3333
nghr,
34-
"scratch",
34+
Buffer.scratch,
3535
1,
3636
sendtb=0,
3737
chan_type=ChannelType.port,
3838
)
39-
c.signal(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
40-
c.flush(nghr, "scratch", 1, sendtb=0, chan_type=ChannelType.port)
39+
c.signal(nghr, Buffer.scratch, 1, sendtb=0, chan_type=ChannelType.port)
40+
c.flush(nghr, Buffer.scratch, 1, sendtb=0, chan_type=ChannelType.port)
4141

4242
for r in range(size):
43-
c = chunk(r, "scratch", 1)
43+
c = chunk(r, Buffer.scratch, 1)
4444
c.wait(1 - r, Buffer.input, 0, recvtb=0, chan_type=ChannelType.port)
4545
c.copy(r, Buffer.output, 0, sendtb=0)
4646

0 commit comments

Comments
 (0)