diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 01468d87e42a..ea213cffd423 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -4408,3 +4408,10 @@ def reload_states(self, non_blocking: bool = False) -> None:
                           DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer."
 
         self.optimizer.reload_states(non_blocking=non_blocking)
+
+    def set_all_reduce_hook(
+        self,
+        hook: Callable[[torch.Tensor], None],
+    ):
+        if hasattr(self.optimizer, "_all_reduce_hook"):
+            self.optimizer._all_reduce_hook = hook
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 0828bd7c755b..29a6d1700f21 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -488,6 +488,8 @@ def _enforce_optimizer_offload():
 
         self.offloaded_states: Set[OffloadDeviceEnum] = set()
 
+        self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None
+
         if dist.get_rank(group=self.dp_process_group) == 0:
             see_memory_usage("After initializing ZeRO optimizer", force=True)
 
@@ -1406,7 +1408,11 @@ def __reduce_and_partition_ipg_grads(self, communication_data_type: torch.dtype)
             if self.zenflow and self.micro_step >= self.full_warm_up_rounds:
                 self._process_selected_fp32_groups_grad(params_in_bucket, grad_partitions)
 
-            self.partition_grads(params_in_bucket, grad_partitions)
+            grad_buffers = self.partition_grads(params_in_bucket, grad_partitions)
+
+            if self._all_reduce_hook and self.is_gradient_accumulation_boundary:
+                for grad_buffer in grad_buffers:
+                    self._all_reduce_hook(grad_buffer)
 
             params_in_bucket.clear()
             bucket.elements = 0
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 86f50a1a0c0b..5baccad6165d 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # DeepSpeed Team
+from typing import Optional, Callable
 
 import torch
 from deepspeed import comm as dist
@@ -630,6 +631,7 @@ def _enforce_cpu_offload():
         self._create_optimizer_mapping()
 
         self.offloaded_states: Set[OffloadStateTypeEnum] = set()
+        self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None
 
     def destroy(self):
         for i, _ in enumerate(self.optimizer.param_groups):
@@ -1480,6 +1482,10 @@ def copy_grads_in_partition(self, param):
         #print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
         self.grads_in_partition_offset += param.numel()
 
+    def all_reduce_hook(self, tensor):
+        if self._all_reduce_hook:
+            self._all_reduce_hook(tensor)
+
     def reduce_ipg_grads(self):
         for comm_dtype in sort_dtypes(self.ipg_buckets.keys()):
             bucket = self.ipg_buckets[comm_dtype]
@@ -1516,6 +1522,7 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()
 
         with get_accelerator().stream(stream):
+            grad_buffers = []
             for comm_dtype in sort_dtypes(self.ipg_buckets.keys()):
                 bucket = self.ipg_buckets[comm_dtype]
 
@@ -1536,12 +1543,22 @@ def reduce_ipg_grads(self):
                                 self.previous_reduced_grads[comm_dtype].append(param)
                             else:
                                 self.clear_grad_attribute(param)
-                        elif self.contiguous_gradients:
-                            self.copy_grads_in_partition(param)
+                        else:
+                            if self.contiguous_gradients:
+                                self.copy_grads_in_partition(param)
+                            grad_buffers.append(self.get_gradient_for_reduction(param))
                     else:  # zero stage 1 - partition only optimizer state
-                        if self.contiguous_gradients and self.is_param_in_current_partition[param_id]:
-                            self.copy_grads_in_partition(param)
+                        if self.is_param_in_current_partition[param_id]:
+                            if self.contiguous_gradients:
+                                self.copy_grads_in_partition(param)
+                            grad_buffers.append(self.get_gradient_for_reduction(param))
+                bucket.clear()
+
+            if self._all_reduce_hook:
+                for grad_buffer in grad_buffers:
+                    if grad_buffer is not None:
+                        self.all_reduce_hook(grad_buffer)
 
     #####################################################################
 
     def process_gradients(self, param, i):
diff --git a/tests/unit/runtime/zero/test_zero_allreduce_hook.py b/tests/unit/runtime/zero/test_zero_allreduce_hook.py
new file mode 100644
index 000000000000..bcab1911db4f
--- /dev/null
+++ b/tests/unit/runtime/zero/test_zero_allreduce_hook.py
@@ -0,0 +1,223 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+import deepspeed.comm as dist
+from unit.common import DistributedTest
+from unit.simple_model import SimpleModel, random_dataloader
+from deepspeed.accelerator import get_accelerator
+
+
+@pytest.mark.parametrize("zero_stage", [1, 2, 3])
+@pytest.mark.parametrize("contiguous_gradients", [True, False])
+@pytest.mark.parametrize("reduce_bucket_size", [500000000, 10])
+@pytest.mark.parametrize("reduce_scatter", [True, False])
+@pytest.mark.parametrize("overlap_comm", [True, False])
+@pytest.mark.parametrize("gradient_accumulation_steps", [1, 4])
+class TestZeroAllReduceHook(DistributedTest):
+    """Test _all_reduce_hook functionality for ZeRO stage 1, 2 and 3"""
+    world_size = 4  # 4 processes to simulate 2 replica groups
+
+    def test(self, zero_stage, contiguous_gradients, reduce_bucket_size, reduce_scatter, overlap_comm,
+             gradient_accumulation_steps):
+        """
+        Test that _all_reduce_hook is called correctly and performs cross-replica gradient sync.
+
+        Setup:
+        - 4 processes split into 2 replica groups
+        - Replica group 0: ranks [0, 1]
+        - Replica group 1: ranks [2, 3]
+        - Same initial parameters across all ranks
+        - Different training data per replica group
+
+        Verification:
+        - Hook is called with gradient tensors
+        - All ranks have identical model parameters after training (proves gradient sync works)
+        """
+
+        rank = dist.get_rank()
+
+        # Create replica groups
+        replica_group_0_ranks = [0, 1]
+        replica_group_1_ranks = [2, 3]
+
+        replica_group_0 = dist.new_group(ranks=replica_group_0_ranks)
+        replica_group_1 = dist.new_group(ranks=replica_group_1_ranks)
+
+        if rank in replica_group_0_ranks:
+            replica_dp_group = replica_group_0
+            replica_id = 0
+        else:
+            replica_dp_group = replica_group_1
+            replica_id = 1
+
+        # Create cross-replica groups for gradient synchronization
+        # IMPORTANT: All ranks must call dist.new_group() for all groups!
+        cross_replica_group_0 = dist.new_group(ranks=[0, 2])  # All 4 ranks must call this
+        cross_replica_group_1 = dist.new_group(ranks=[1, 3])  # All 4 ranks must call this
+
+        local_rank_in_replica = rank % 2
+        if local_rank_in_replica == 0:
+            cross_replica_group = cross_replica_group_0
+        else:
+            cross_replica_group = cross_replica_group_1
+
+        # Create a custom MPU object to specify replica-specific DP group
+        # This is crucial for stage 3 to ensure parameters are sharded correctly
+        class ReplicaMPU:
+            """Custom MPU that provides replica-specific data parallel group"""
+
+            def __init__(self, dp_group):
+                self._dp_group = dp_group
+
+            def get_data_parallel_group(self):
+                return self._dp_group
+
+            def get_data_parallel_world_size(self):
+                return dist.get_world_size(group=self._dp_group)
+
+            def get_data_parallel_rank(self):
+                return dist.get_rank(group=self._dp_group)
+
+            def get_model_parallel_world_size(self):
+                """Return 1 as we don't use model parallelism"""
+                return 1
+
+            def get_model_parallel_rank(self):
+                """Return 0 as we don't use model parallelism"""
+                return 0
+
+            def get_model_parallel_group(self):
+                """Return None as we don't use model parallelism"""
+                return None
+
+        replica_mpu = ReplicaMPU(replica_dp_group)
+
+        # Track hook invocations
+        hook_call_count = [0]
+        hook_tensors = []
+
+        def cross_replica_gradient_sync_hook(tensor):
+            """Hook that averages gradients across replica groups"""
+            hook_call_count[0] += 1
+            hook_tensors.append(tensor.clone().detach())
+            # Synchronize gradients across replica groups
+            dist.all_reduce(tensor, op=dist.ReduceOp.AVG, group=cross_replica_group)
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 2,
+            "gradient_accumulation_steps": gradient_accumulation_steps,
+            "steps_per_print": 1,
+            "zero_optimization": {
+                "stage": zero_stage,
+                "contiguous_gradients": contiguous_gradients,
+                "reduce_bucket_size": reduce_bucket_size,
+                "reduce_scatter": reduce_scatter,
+                "overlap_comm": overlap_comm,
+            },
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.01
+                }
+            },
+        }
+
+        # Stage 3 specific configuration
+        if zero_stage == 3:
+            config_dict["zero_optimization"]["stage3_param_persistence_threshold"] = 0
+
+        if get_accelerator().is_fp16_supported():
+            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
+        elif get_accelerator().is_bf16_supported():
+            config_dict["bf16"] = {"enabled": True}
+
+        hidden_dim = 10
+
+        # Create model with same initial parameters for all ranks
+        torch.manual_seed(42)  # Same seed for all ranks
+        model = SimpleModel(hidden_dim=hidden_dim)
+
+        # Pass the replica_mpu to deepspeed.initialize so that parameters
+        # are sharded according to the replica-specific DP group from the start
+        model, _, _, _ = deepspeed.initialize(
+            config=config_dict,
+            model=model,
+            model_parameters=model.parameters(),
+            dist_init_required=False,
+            mpu=replica_mpu,
+        )
+
+        # Set the all_reduce_hook
+        model.set_all_reduce_hook(cross_replica_gradient_sync_hook)
+
+        # Create different data for different replica groups
+        torch.manual_seed(42 + replica_id)  # Different seed for different replicas
+
+        # Ensure we have enough samples for all steps
+        # Each step consumes train_micro_batch_size_per_gpu samples
+        num_steps = 3 * gradient_accumulation_steps
+        train_batch_size = config_dict["train_micro_batch_size_per_gpu"]
+        total_samples_needed = num_steps * train_batch_size
+        data_loader = random_dataloader(
+            model=model,
+            total_samples=total_samples_needed + train_batch_size,  # Extra samples for safety
+            hidden_dim=hidden_dim,
+            device=model.device,
+        )
+
+        # Reset counters
+        hook_call_count[0] = 0
+        hook_tensors.clear()
+
+        # Train for a few steps
+        for step_id, batch in enumerate(data_loader):
+            if step_id >= num_steps:
+                break
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+        # Verify hook was called
+        assert hook_call_count[0] > 0, \
+            f"Hook should be called for stage={zero_stage}, contiguous={contiguous_gradients}, bucket_size={reduce_bucket_size}, reduce_scatter={reduce_scatter}, overlap_comm={overlap_comm}"
+
+        # Verify tensors were passed to hook
+        assert len(hook_tensors) > 0, "Hook should receive gradient tensors"
+
+        # Verify all tensors are valid
+        non_empty_tensors = [t for t in hook_tensors if t.numel() > 0]
+        assert len(non_empty_tensors) > 0, \
+            f"At least some hook tensors should have elements. Got {len(hook_tensors)} total tensors, all empty."
+
+        for tensor in non_empty_tensors:
+            assert tensor is not None, "Hook tensor should not be None"
+            assert tensor.device.type == get_accelerator().device_name(), \
+                f"Tensor should be on {get_accelerator().device_name()}"
+
+        # Synchronize before checking parameters
+        dist.barrier()
+
+        # Verify that all ranks have identical model parameters
+        # This proves cross-replica gradient synchronization worked
+        if zero_stage == 3:
+            with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=None):
+                param_list = [p.data.clone() for p in model.parameters()]
+        else:
+            param_list = [p.data.clone() for p in model.parameters()]
+
+        for param_idx, param in enumerate(param_list):
+            gathered_params = [torch.zeros_like(param) for _ in range(self.world_size)]
+            dist.all_gather(gathered_params, param)
+
+            if rank == 0:
+                for other_rank in range(1, self.world_size):
+                    assert torch.allclose(gathered_params[0], gathered_params[other_rank], rtol=1e-3, atol=1e-5), \
+                        f"Parameters differ between rank 0 and rank {other_rank} at param_idx={param_idx}. " \
+                        f"Cross-replica gradient sync failed for stage={zero_stage}, contiguous={contiguous_gradients}, bucket_size={reduce_bucket_size}, reduce_scatter={reduce_scatter}, overlap_comm={overlap_comm}!"
+
+        model.destroy()
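
A minimal usage sketch (not part of the diff) showing how the set_all_reduce_hook API added above might be wired to average reduced gradient partitions across a second process group; the replica-group layout, model, and config values below are illustrative assumptions, not taken from this PR.

import torch
import deepspeed
import deepspeed.comm as dist

ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "Adam", "params": {"lr": 0.01}},
}

model = torch.nn.Linear(16, 16)
engine, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters())

# Hypothetical cross-replica group; every rank must participate in the dist.new_group() call.
cross_replica_group = dist.new_group(ranks=[0, 2])


def sync_across_replicas(grad_buffer: torch.Tensor) -> None:
    # Invoked by the ZeRO optimizer with each reduced gradient partition buffer.
    dist.all_reduce(grad_buffer, op=dist.ReduceOp.AVG, group=cross_replica_group)


engine.set_all_reduce_hook(sync_across_replicas)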