Skip to content

Commit 87ec25b

Browse files
committed
Added a few things to the NCCL communicator
1 parent 31d83a2 commit 87ec25b

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

pySDC/helpers/NCCL_communicator.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __getattr__(self, name):
2727
Args:
2828
Name (str): Name of the requested attribute
2929
"""
30-
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split']:
30+
if name not in ['size', 'rank', 'Get_rank', 'Get_size', 'Split', 'Create_cart', 'Is_inter', 'Get_topology']:
3131
cp.cuda.get_current_stream().synchronize()
3232

3333
return getattr(self.commMPI, name)
@@ -71,6 +71,26 @@ def get_op(self, MPI_op):
7171
else:
7272
raise NotImplementedError('Don\'t know what NCCL operation to use to replace this MPI operation!')
7373

74+
def reduce(self, sendobj, op=MPI.SUM, root=0):
75+
sync = False
76+
if hasattr(sendobj, 'data'):
77+
if hasattr(sendobj.data, 'ptr'):
78+
sync = True
79+
if sync:
80+
cp.cuda.Device().synchronize()
81+
82+
return self.commMPI.reduce(sendobj, op=op, root=root)
83+
84+
def allreduce(self, sendobj, op=MPI.SUM):
85+
sync = False
86+
if hasattr(sendobj, 'data'):
87+
if hasattr(sendobj.data, 'ptr'):
88+
sync = True
89+
if sync:
90+
cp.cuda.Device().synchronize()
91+
92+
return self.commMPI.allreduce(sendobj, op=op)
93+
7494
def Reduce(self, sendbuf, recvbuf, op=MPI.SUM, root=0):
7595
if not hasattr(sendbuf.data, 'ptr'):
7696
return self.commMPI.Reduce(sendbuf=sendbuf, recvbuf=recvbuf, op=op, root=root)
@@ -113,3 +133,7 @@ def Bcast(self, buf, root=0):
113133
stream = cp.cuda.get_current_stream()
114134

115135
self.commNCCL.bcast(buff=buf.data.ptr, count=count, datatype=dtype, root=root, stream=stream.ptr)
136+
137+
def Barrier(self):
138+
cp.cuda.get_current_stream().synchronize()
139+
self.commMPI.Barrier()

0 commit comments

Comments
 (0)