Skip to content

Commit 4c89cf7

Browse files
committed
Fix a synchronization issue and use async all-gather ops to overlap the Muon update with communication
1 parent 1a35008 commit 4c89cf7

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

deepspeed/runtime/zero/stage3.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,8 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
15241524
(world_sz - len(params) % world_sz) % world_sz)
15251525
gathered_momentums_pad = gathered_params_momentums + [torch.empty_like(gathered_params_momentums[-1])] * (
15261526
(world_sz - len(gathered_params_momentums) % world_sz) % world_sz)
1527+
grad_handles = []
1528+
momentum_handles = []
15271529
for base_i in range(len(params))[::world_sz]:
15281530
if base_i + rank < len(params):
15291531
param = params[base_i + rank]
@@ -1534,10 +1536,14 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
15341536
_, _, grad_offset = group_items[base_i + rank]
15351537
buffer_to_reduce.narrow(0, grad_offset,
15361538
param.grad.numel()).data.copy_(g.view(-1), non_blocking=False)
1537-
dist.all_gather(grads_pad[base_i:base_i + world_sz], grads_pad[base_i + rank])
1538-
dist.all_gather(gathered_momentums_pad[base_i:base_i + world_sz],
1539-
gathered_momentums_pad[base_i + rank])
1540-
1539+
grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz], grads_pad[base_i + rank], async_op=True)
1540+
grad_handles.append(grad_handle)
1541+
momentum_handle = dist.all_gather(gathered_momentums_pad[base_i:base_i + world_sz],
1542+
gathered_momentums_pad[base_i + rank], async_op=True)
1543+
momentum_handles.append(momentum_handle)
1544+
1545+
for handle in momentum_handles:
1546+
handle.wait()
15411547
for idx, (param, dest_offset, _) in enumerate(group_items):
15421548
gathered_momentum = gathered_params_momentums[idx]
15431549
chunk_sz = math.ceil(param.grad.numel() / world_sz)
@@ -1564,6 +1570,11 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
15641570
)
15651571
if self._swappable_optimizer_subgroup(i) and not self.save_muon_momentum_buffer_in_memory:
15661572
self.optimizer_swapper.swap_out_optimizer_state(parameter=self.fp32_partitioned_groups_flat[i])
1573+
for handle in grad_handles:
1574+
handle.wait()
1575+
for param, _, params_size_offset in group_items:
1576+
buffer_to_reduce.narrow(0, params_size_offset,
1577+
param.grad.numel()).data.copy_(param.grad.view(-1), non_blocking=False)
15671578

15681579
@instrument_w_nvtx
15691580
def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor,

0 commit comments

Comments
 (0)