
Commit 51e225e

Merge branch 'master' into fix/bf16-zero3-quantized-weights
2 parents: c072800 + 156dcb2

3 files changed: +9 −9 lines changed


deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 7 additions & 8 deletions
@@ -118,7 +118,6 @@ def clear(self):
         self.params.clear()
         self.grads.clear()
         self.elements = 0
-        self.index = 0
         self.has_moe_params = False
@@ -1052,11 +1051,8 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         bucket = self.ipg_buckets[comm_dtype]
         if bucket.elements + param.numel() > self.reduce_bucket_size:
             self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel())
-            self.reduce_ipg_grads()
+            self.reduce_ipg_grads(comm_dtype=comm_dtype)
             if self.contiguous_gradients and self.overlap_comm:
-                if not get_accelerator().resolves_data_dependency():
-                    self.reduction_stream.wait_stream(get_accelerator().current_stream())
-                    get_accelerator().current_stream().wait_stream(self.reduction_stream)
                 # Swap index between 0 and 1
                 bucket.index = 1 - bucket.index
             self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel())
@@ -1500,8 +1496,11 @@ def copy_grads_in_partition(self, param):
         #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
         self.grads_in_partition_offset += param.numel()

-    def reduce_ipg_grads(self):
-        for comm_dtype in sort_dtypes(self.ipg_buckets.keys()):
+    def reduce_ipg_grads(self, comm_dtype=None):
+        dtypes = sort_dtypes(self.ipg_buckets.keys())
+        if comm_dtype is not None:
+            dtypes = [comm_dtype]
+        for comm_dtype in dtypes:
             bucket = self.ipg_buckets[comm_dtype]

             if self.contiguous_gradients:
@@ -1536,7 +1535,7 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()

         with get_accelerator().stream(stream):
-            for comm_dtype in sort_dtypes(self.ipg_buckets.keys()):
+            for comm_dtype in dtypes:
                 bucket = self.ipg_buckets[comm_dtype]

                 for group_idx, param_idx_in_group, param_id in bucket.params:
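The net effect in stage_1_and_2.py: when one dtype's IPG bucket overflows, only that bucket is now reduced, instead of sweeping every dtype's bucket. A minimal sketch of the pattern, assuming hypothetical stand-ins (`Bucket`, `buckets`, `allreduce_and_reset`) for the real IPG bookkeeping:

```python
from collections import defaultdict


class Bucket:
    """Hypothetical stand-in for the per-dtype IPG bucket state."""

    def __init__(self):
        self.params = []
        self.elements = 0
        self.index = 0


buckets = defaultdict(Bucket)


def allreduce_and_reset(bucket):
    # Placeholder for the real all-reduce over the bucket's gradients.
    bucket.params.clear()
    bucket.elements = 0


def reduce_ipg_grads(comm_dtype=None):
    # comm_dtype given (the overflow path): flush only that bucket.
    # comm_dtype=None (e.g. an end-of-step flush): sweep all buckets
    # in a deterministic order, mirroring what sort_dtypes() provides.
    dtypes = sorted(buckets.keys(), key=str)
    if comm_dtype is not None:
        dtypes = [comm_dtype]
    for dtype in dtypes:
        allreduce_and_reset(buckets[dtype])
```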

tests/unit/ops/deepspeed4science/test_DS4Sci_EvoformerAttention.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def attention_reference(
     return o


+@pytest.mark.sequential
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("tensor_shape", [(1, 256, 256, 4, 32), (1, 512, 256, 8, 8)])
 def test_DS4Sci_EvoformerAttention(dtype, tensor_shape):
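`@pytest.mark.sequential` only tags the test; the runner or a conftest.py hook decides what the tag means (here, presumably, keeping the test out of parallel workers). A minimal illustration of tagging and selecting by a custom marker; the test body is hypothetical:

```python
import pytest


@pytest.mark.sequential
def test_runs_serially():
    # The mark carries no behavior by itself. Select marked tests with
    # `pytest -m sequential`; exclude them with `pytest -m "not sequential"`.
    assert True
```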

tests/unit/runtime/zero/test_zero_tensor_fragment.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def test_bf16_optimizer_fragments(self, frozen_weights):
                 "grad_accum_dtype": "fp32"
             },
             "zero_optimization": {
-                "stage": 0,
+                "stage": 1,
             }
         }
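For context, a hedged sketch of the config shape this test builds after the change; keys not visible in the diff are assumptions based on common DeepSpeed configs:

```python
# Hypothetical minimal DeepSpeed config in the shape the test exercises:
# ZeRO stage 1 (was stage 0) with fp32 gradient accumulation under bf16.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "data_types": {"grad_accum_dtype": "fp32"},
    "zero_optimization": {"stage": 1},
}
```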
