
Commit 15ad92b

Fix ping-pong buffer index reset and remove redundant stream sync (deepspeedai#7805)
Fix deepspeedai#7804 and deepspeedai#7188.

After investigating the code in `deepspeed/runtime/zero/stage_1_and_2.py`, I have identified the root cause. The regression in communication overlap was introduced in PR deepspeedai#7371. While the additional two-stream synchronization in that PR fixes gradient corruption, it effectively disables the overlapping behavior.

The underlying gradient corruption (which deepspeedai#7371 attempted to fix) was actually introduced in PR deepspeedai#6993. In that PR, `bucket.clear()` incorrectly resets the ping-pong buffer index to 0 at the end of `reduce_ipg_grads`. This disrupts the buffer index swapping mechanism in `reduce_independent_p_g_buckets_and_remove_grads`.

To fix this, line 121 of `deepspeed/runtime/zero/stage_1_and_2.py` is removed so the buffer index is no longer reset, and the stream synchronization logic introduced in deepspeedai#7371 is removed to restore the `overlap_comm=True` functionality.

---------

Signed-off-by: szlent <metarufolds@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Masahiro Tanaka <mtanaka@anyscale.com>
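To illustrate why resetting the index breaks overlap, here is a minimal, self-contained sketch of a two-slot "ping-pong" bucket. The class and method names (`PingPongBucket`, `clear_current`, `swap`) are hypothetical and simplified, not DeepSpeed's actual data structures: while the slot at `index` is being reduced asynchronously, new gradients accumulate into slot `1 - index`; if `clear()` forced `index` back to 0, the producer could write into the slot still being consumed, and correctness would then require exactly the kind of full stream synchronization that kills overlap.

```python
class PingPongBucket:
    """Simplified double-buffered gradient bucket (illustrative only)."""

    def __init__(self):
        self.slots = [[], []]  # two gradient buffers
        self.index = 0         # slot currently being filled

    def add(self, grad):
        self.slots[self.index].append(grad)

    def clear_current(self):
        # Clear only the slot contents; do NOT reset self.index here,
        # or the next fill would collide with the in-flight reduce.
        self.slots[self.index].clear()

    def swap(self):
        # After launching an async reduce on the current slot,
        # switch filling to the other slot so compute and
        # communication can proceed concurrently.
        self.index = 1 - self.index


bucket = PingPongBucket()
bucket.add("g0")       # fill slot 0
bucket.swap()          # reduce of slot 0 in flight; now fill slot 1
bucket.add("g1")
assert bucket.index == 1
assert bucket.slots == [["g0"], ["g1"]]
```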
1 parent 3bc882f commit 15ad92b

File tree

1 file changed (+7, −8 lines)

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 7 additions & 8 deletions
@@ -118,7 +118,6 @@ def clear(self):
         self.params.clear()
         self.grads.clear()
         self.elements = 0
-        self.index = 0
         self.has_moe_params = False


@@ -1052,11 +1051,8 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         bucket = self.ipg_buckets[comm_dtype]
         if bucket.elements + param.numel() > self.reduce_bucket_size:
             self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
-            self.reduce_ipg_grads()
+            self.reduce_ipg_grads(comm_dtype=comm_dtype)
             if self.contiguous_gradients and self.overlap_comm:
-                if not get_accelerator().resolves_data_dependency():
-                    self.reduction_stream.wait_stream(get_accelerator().current_stream())
-                get_accelerator().current_stream().wait_stream(self.reduction_stream)
                 # Swap index between 0 and 1
                 bucket.index = 1 - bucket.index
             self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())
@@ -1500,8 +1496,11 @@ def copy_grads_in_partition(self, param):
             #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
             self.grads_in_partition_offset += param.numel()

-    def reduce_ipg_grads(self):
-        for comm_dtype in sort_dtypes(self.ipg_buckets.keys()):
+    def reduce_ipg_grads(self, comm_dtype=None):
+        dtypes = sort_dtypes(self.ipg_buckets.keys())
+        if comm_dtype is not None:
+            dtypes = [comm_dtype]
+        for comm_dtype in dtypes:
             bucket = self.ipg_buckets[comm_dtype]

             if self.contiguous_gradients:
@@ -1536,7 +1535,7 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()

             with get_accelerator().stream(stream):
-                for comm_dtype in sort_dtypes(self.ipg_buckets.keys()):
+                for comm_dtype in dtypes:
                     bucket = self.ipg_buckets[comm_dtype]

                     for group_idx, param_idx_in_group, param_id in bucket.params:
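The new `comm_dtype` parameter narrows the reduce to a single bucket when called from the hot path, instead of flushing every dtype's bucket. A minimal sketch of that selection logic, with a hypothetical `select_dtypes` helper standing in for the inline code (the real `sort_dtypes` is assumed to produce a deterministic ordering across ranks):

```python
def select_dtypes(all_dtypes, comm_dtype=None):
    # Sort for a deterministic order across ranks, as the real
    # sort_dtypes() is assumed to do; str() as key is illustrative.
    dtypes = sorted(all_dtypes, key=str)
    # When a specific dtype is requested, reduce only that bucket.
    if comm_dtype is not None:
        dtypes = [comm_dtype]
    return dtypes


# No dtype given: all buckets, in sorted order.
assert select_dtypes(["fp32", "bf16"]) == ["bf16", "fp32"]
# Specific dtype given: only that bucket is reduced.
assert select_dtypes(["fp32", "bf16"], comm_dtype="bf16") == ["bf16"]
```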
