Skip to content

Commit cc40fe0

Browse files
authored
[fix] multi-node backward slowdown (#6134)
* Remove redundant memcpy during backward
* Get back record_stream
1 parent c2fe313 commit cc40fe0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

colossalai/zero/low_level/bookkeeping/bucket_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ def build_grad_in_bucket(self):
7878
}
7979
"""
8080
for param, padding_size in zip(self._param_list, self._padding_size):
81-
grad = param.grad.clone().detach().flatten()
81+
grad = param.grad.detach().flatten()
8282
if padding_size > 0:
8383
with torch.no_grad():
8484
grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
8585
grad_list = grad.split(grad.numel() // self._world_size)
8686
for rank in range(self._world_size):
87-
grad_current_rank = grad_list[rank].clone().detach()
87+
grad_current_rank = grad_list[rank].detach()
8888
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
8989
self._grad_in_bucket[rank].append(grad_current_rank)
9090
param.grad = None
@@ -110,7 +110,7 @@ def get_flatten_grad(self) -> Tensor:
110110

111111
flat_grad = []
112112
for grad_list in self._grad_in_bucket.values():
113-
flat_grad.append(_flatten_dense_tensors(grad_list))
113+
flat_grad.extend(grad_list)
114114
flat_grad = _flatten_dense_tensors(flat_grad)
115115
return flat_grad
116116

0 commit comments

Comments (0)