Skip to content

Commit dd4f237

Browse files
authored
fix paddle.distributed.broadcast/scatter from async to sync (#22)
Signed-off-by: zeroRains <linjunlu@zerorains.top>
1 parent 18391ca commit dd4f237

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

fastsafetensors/tensor_factory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non
106106
if self.metadata.framework == "pytorch":
107107
dist.broadcast(dst, self.rank, group=pg)
108108
elif paddle_loaded and self.metadata.framework == "paddle":
109-
pdist.broadcast(dst, self.rank, group=group, sync_op=False)
109+
pdist.broadcast(dst, self.rank, group=group)
110110
else:
111111
rank_slices: List[Tuple] = [() for i in range(0, pg.size())]
112112
size = frame.shape[dim]
@@ -135,7 +135,7 @@ def shuffle(self, pg: dist.ProcessGroup, tensor_name: str, dim: int, group = Non
135135
if self.metadata.framework == "pytorch":
136136
dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg)
137137
elif paddle_loaded and self.metadata.framework == "paddle":
138-
pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False)
138+
pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group)
139139
self.shuffled[tensor_name] = dst
140140
return dst
141141

@@ -173,7 +173,7 @@ def shuffle_packed_qkv(self, pg: dist.ProcessGroup, tensor_name: str, group = No
173173
dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg)
174174
elif paddle_loaded and self.metadata.framework == "paddle":
175175
dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype),place=self.device)
176-
pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False)
176+
pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group)
177177
self.shuffled[tensor_name] = dst
178178
return dst
179179

@@ -214,7 +214,7 @@ def shuffle_multi_cols(self, pg: dist.ProcessGroup, tensor_names: List[str], dim
214214
dist.scatter(dst, scatter_list=scatter_list, src=self.rank, group=pg)
215215
elif paddle_loaded and self.metadata.framework == "paddle":
216216
dst = paddle.to_tensor(paddle.empty(shape=new_shape, dtype=frame.dtype), place=self.device )# dst should be eariler than scatter_list for less fragmentation
217-
pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group, sync_op=False)
217+
pdist.scatter(dst, tensor_list=scatter_list, src=self.rank, group=group)
218218
return dst
219219

220220
def free_dev_ptrs(self):

0 commit comments

Comments
 (0)