@@ -27,7 +27,7 @@ def __getattr__(self, name):
2727 Args:
2828 Name (str): Name of the requested attribute
2929 """
30- if name not in ['size' , 'rank' , 'Get_rank' , 'Get_size' , 'Split' ]:
30+ if name not in ['size' , 'rank' , 'Get_rank' , 'Get_size' , 'Split' , 'Create_cart' , 'Is_inter' , 'Get_topology' ]:
3131 cp .cuda .get_current_stream ().synchronize ()
3232
3333 return getattr (self .commMPI , name )
@@ -71,6 +71,26 @@ def get_op(self, MPI_op):
7171 else :
7272 raise NotImplementedError ('Don\' t know what NCCL operation to use to replace this MPI operation!' )
7373
def reduce(self, sendobj, op=MPI.SUM, root=0):
    """
    Forward a lower-case (generic object) reduction to the wrapped MPI communicator.

    If `sendobj` exposes a `data` attribute that itself has a `ptr` attribute
    (presumably a CuPy array living on the GPU -- confirm against callers), the
    current CUDA device is synchronized before the object is handed to MPI.

    Args:
        sendobj: Object to be reduced
        op: Reduction operation passed through to MPI
        root (int): Rank passed through to MPI as the root of the reduction

    Returns:
        Whatever `self.commMPI.reduce` returns
    """
    # Only objects carrying a device pointer require the device to be synchronized first.
    if hasattr(sendobj, 'data') and hasattr(sendobj.data, 'ptr'):
        cp.cuda.Device().synchronize()

    return self.commMPI.reduce(sendobj, op=op, root=root)
83+
def allreduce(self, sendobj, op=MPI.SUM):
    """
    Forward a lower-case (generic object) all-reduce to the wrapped MPI communicator.

    If `sendobj` exposes a `data` attribute that itself has a `ptr` attribute
    (presumably a CuPy array living on the GPU -- confirm against callers), the
    current CUDA device is synchronized before the object is handed to MPI.

    Args:
        sendobj: Object to be reduced
        op: Reduction operation passed through to MPI

    Returns:
        Whatever `self.commMPI.allreduce` returns
    """
    # Only objects carrying a device pointer require the device to be synchronized first.
    if hasattr(sendobj, 'data') and hasattr(sendobj.data, 'ptr'):
        cp.cuda.Device().synchronize()

    return self.commMPI.allreduce(sendobj, op=op)
93+
7494 def Reduce (self , sendbuf , recvbuf , op = MPI .SUM , root = 0 ):
7595 if not hasattr (sendbuf .data , 'ptr' ):
7696 return self .commMPI .Reduce (sendbuf = sendbuf , recvbuf = recvbuf , op = op , root = root )
@@ -113,3 +133,7 @@ def Bcast(self, buf, root=0):
113133 stream = cp .cuda .get_current_stream ()
114134
115135 self .commNCCL .bcast (buff = buf .data .ptr , count = count , datatype = dtype , root = root , stream = stream .ptr )
136+
def Barrier(self):
    """
    Synchronize the current CUDA stream and then enter the barrier of the
    wrapped MPI communicator.
    """
    # Finish the work queued on the current stream before blocking in MPI.
    current_stream = cp.cuda.get_current_stream()
    current_stream.synchronize()
    self.commMPI.Barrier()
0 commit comments