@@ -106,7 +106,7 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non
106106 if self .metadata .framework == "pytorch" :
107107 dist .broadcast (dst , self .rank , group = pg )
108108 elif paddle_loaded and self .metadata .framework == "paddle" :
109- pdist .broadcast (dst , self .rank , group = group , sync_op = False )
109+ pdist .broadcast (dst , self .rank , group = group )
110110 else :
111111 rank_slices : List [Tuple ] = [() for i in range (0 , pg .size ())]
112112 size = frame .shape [dim ]
@@ -135,7 +135,7 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non
135135 if self .metadata .framework == "pytorch" :
136136 dist .scatter (dst , scatter_list = scatter_list , src = self .rank , group = pg )
137137 elif paddle_loaded and self .metadata .framework == "paddle" :
138- pdist .scatter (dst , tensor_list = scatter_list , src = self .rank , group = group , sync_op = False )
138+ pdist .scatter (dst , tensor_list = scatter_list , src = self .rank , group = group )
139139 self .shuffled [tensor_name ] = dst
140140 return dst
141141
@@ -173,7 +173,7 @@ def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str, group = No
173173 dist .scatter (dst , scatter_list = scatter_list , src = self .rank , group = pg )
174174 elif paddle_loaded and self .metadata .framework == "paddle" :
175175 dst = paddle .to_tensor (paddle .empty (shape = new_shape , dtype = frame .dtype ),place = self .device )
176- pdist .scatter (dst , tensor_list = scatter_list , src = self .rank , group = group , sync_op = False )
176+ pdist .scatter (dst , tensor_list = scatter_list , src = self .rank , group = group )
177177 self .shuffled [tensor_name ] = dst
178178 return dst
179179
@@ -214,7 +214,7 @@ def shuffle_multi_cols(self, pg: dist.ProcessGroup, tensor_names: List[str], dim
214214 dist .scatter (dst , scatter_list = scatter_list , src = self .rank , group = pg )
215215 elif paddle_loaded and self .metadata .framework == "paddle" :
216216 dst = paddle .to_tensor (paddle .empty (shape = new_shape , dtype = frame .dtype ), place = self .device )# dst should be eariler than scatter_list for less fragmentation
217- pdist .scatter (dst , tensor_list = scatter_list , src = self .rank , group = group , sync_op = False )
217+ pdist .scatter (dst , tensor_list = scatter_list , src = self .rank , group = group )
218218 return dst
219219
220220 def free_dev_ptrs (self ):
0 commit comments