Commit 77ec773

[zero]remove registered gradients hooks (#5687)
* remove registered hooks fix fix fix zero fix fix fix fix fix zero fix zero fix fix fix
* fix fix fix
1 parent c25f83c commit 77ec773

7 files changed: +256 -167 lines changed
colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 4 additions & 4 deletions
@@ -735,7 +735,7 @@ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
-        if self.require_grad_sync and grads_to_sync is not None:
+        if self._grad_store.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
             SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:

@@ -759,7 +759,7 @@ def backward(self, loss, retain_graph=False):
         # Call the superclass backward method to compute gradients.
         super().backward(loss, retain_graph)

-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:

@@ -784,7 +784,7 @@ def backward_by_grad(self, tensor, grad):
         # Call the superclass backward_by_grad method to compute gradients.
         super().backward_by_grad(tensor, grad)

-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:

@@ -1272,7 +1272,7 @@ def execute_pipeline(

         # run with gradients accumulation
         if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
         ):
             return outputs
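The plugin-side change is mechanical: every read of the gradient-sync flag now goes through the optimizer's gradient store rather than an attribute on the optimizer itself. Below is a minimal sketch of the new access pattern; `_TinyGradStore` and `_TinyZeroOptimizer` are stand-ins, not the colossalai classes (the real `HybridParallelZeroOptimizer` and `GradientStore` take many more arguments).

```python
# Sketch only: illustrates where the require_grad_sync flag lives after this commit.


class _TinyGradStore:
    def __init__(self, require_grad_sync: bool = True):
        # The flag that used to sit on the optimizer now lives on the store.
        self.require_grad_sync = require_grad_sync


class _TinyZeroOptimizer:
    def __init__(self):
        self._grad_store = _TinyGradStore()

    def _maybe_sync_sp_grads(self) -> str:
        # Mirrors the updated checks in hybrid_parallel_plugin.py.
        if self._grad_store.require_grad_sync:
            return "sync sequence-parallel grads"
        return "skip sync (accumulating)"


opt = _TinyZeroOptimizer()
print(opt._maybe_sync_sp_grads())          # sync sequence-parallel grads
opt._grad_store.require_grad_sync = False  # e.g. toggled during grad accumulation
print(opt._maybe_sync_sp_grads())          # skip sync (accumulating)
```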

colossalai/zero/low_level/bookkeeping/base_store.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ class BaseStore:
     def __init__(self, torch_pg: ProcessGroup):
         self._world_size = dist.get_world_size(group=torch_pg)
         self._local_rank = dist.get_rank(group=torch_pg)
+        self.torch_pg = torch_pg

     @property
     def world_size(self):
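Keeping the raw process group on `BaseStore` lets subclasses derive rank and world size from the same group later, as the new `BucketStore` constructor below does. A tiny sketch of that pattern, using a fake group object because a real `torch.distributed` ProcessGroup needs an initialized backend:

```python
# Sketch only: _FakeGroup stands in for torch.distributed.ProcessGroup so the
# example runs without dist.init_process_group().


class _FakeGroup:
    def __init__(self, world_size: int, rank: int):
        self.world_size = world_size
        self.rank = rank


class _SketchBaseStore:
    def __init__(self, torch_pg: _FakeGroup):
        self._world_size = torch_pg.world_size
        self._local_rank = torch_pg.rank
        self.torch_pg = torch_pg  # new in this commit: keep the group itself


class _SketchBucketStore(_SketchBaseStore):
    def __init__(self, torch_pg: _FakeGroup):
        super().__init__(torch_pg)
        # Subclasses can now re-query the stored group directly.
        self.zero_local_rank = self.torch_pg.rank
        self.zero_world_size = self.torch_pg.world_size


store = _SketchBucketStore(_FakeGroup(world_size=4, rank=1))
assert (store.zero_world_size, store.zero_local_rank) == (4, 1)
```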

colossalai/zero/low_level/bookkeeping/bucket_store.py

Lines changed: 29 additions & 2 deletions
@@ -1,16 +1,43 @@
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
+import torch.distributed as dist
 from torch import Tensor
 from torch._utils import _flatten_dense_tensors
 from torch.distributed import ProcessGroup
 
+from colossalai.accelerator import get_accelerator
+
 from .base_store import BaseStore
 
 
 class BucketStore(BaseStore):
-    def __init__(self, torch_pg: ProcessGroup):
+    def __init__(
+        self,
+        torch_pg: ProcessGroup,
+        reduce_bucket_size: int,
+        overlap_communication: bool,
+        communication_dtype: Optional[torch.dtype] = None,
+        moe_extra_dp_process_group: ProcessGroup = None,
+    ):
         super().__init__(torch_pg)
+        self.reduce_bucket_size = reduce_bucket_size
+        # communication params
+        self._overlap_communication = overlap_communication
+        self._communication_dtype = communication_dtype
+        if self._overlap_communication:
+            self.comm_stream = get_accelerator().Stream()
+        self.zero_local_rank = dist.get_rank(group=self.torch_pg)
+        self.zero_world_size = dist.get_world_size(group=self.torch_pg)
+        # extra dp
+        # This group is used to sync moe param, dp_world_size = moe_duplicates * extra_dp_size.
+        # Non moe param will be sync by global dp pg, moe param will be sync by extra dp pg.
+        # Moe param grad is be split as non moe param by global dp pg, and grad will be merged in step.
+        # And moe working and master param are split by extra dp pg.
+        self.moe_extra_dp_pg = moe_extra_dp_process_group
+        if self.moe_extra_dp_pg is not None:
+            self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
+            self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
         self.reset_all()
 
     def reset_all(self) -> None:
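`BucketStore` now receives the communication settings (reduce bucket size, overlap flag, communication dtype) and the optional MoE extra data-parallel group directly in its constructor. The diff comment states the size relation dp_world_size = moe_duplicates * extra_dp_size; a quick sanity check of that arithmetic, with made-up numbers for illustration:

```python
# Illustrative numbers only: a global data-parallel world of 8 ranks where each
# MoE expert is duplicated across 4 ranks leaves an extra-dp group of size 2.
dp_world_size = 8
moe_duplicates = 4

extra_dp_size = dp_world_size // moe_duplicates
assert dp_world_size == moe_duplicates * extra_dp_size  # relation from the diff comment

# Non-MoE params sync over the global dp group (8 ranks);
# MoE params sync over the extra-dp group (2 ranks).
print(f"global dp group: {dp_world_size} ranks, moe extra dp group: {extra_dp_size} ranks")
```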

colossalai/zero/low_level/bookkeeping/gradient_store.py

Lines changed: 6 additions & 3 deletions
@@ -6,7 +6,7 @@
 
 
 class GradientStore(BaseStore):
-    def __init__(self, *args, partition_grad: bool = False):
+    def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
         super().__init__(*args)
         """
         self._grads_of_params mapping the parameter and its gradient slices

@@ -18,9 +18,12 @@ def __init__(self, *args, partition_grad: bool = False):
         }
         """
         self._grads_of_params = dict()
-        # for zero2, it's `param_id: [grad_local_rank]`
+        # stage 2
+        self._partition_grads = partition_grad
+        # grad accumulation
+        self.require_grad_sync = require_grad_sync
         self._working_index = 0 if partition_grad else self._local_rank
-
+        # for zero2, it's `param_id: [grad_local_rank]`
         self.grad_to_param_mapping = dict()
 
     def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
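`GradientStore` now keeps both flags itself: `_partition_grads` (ZeRO stage 2 gradient partitioning, which also selects the working index) and the new `require_grad_sync` (whether reduction should run, or gradients are being accumulated). A minimal stand-alone sketch of how the working index follows `partition_grad`, mirroring the unchanged line in the diff; `working_index` is a hypothetical helper, and the real store derives the local rank from its process group.

```python
# Sketch only: shows how `_working_index = 0 if partition_grad else local_rank` behaves.


def working_index(partition_grad: bool, local_rank: int) -> int:
    # Stage 2 keeps only the local gradient shard, so every rank reads slot 0;
    # otherwise each rank indexes its own slice by local rank.
    return 0 if partition_grad else local_rank


assert working_index(partition_grad=True, local_rank=3) == 0   # ZeRO stage 2
assert working_index(partition_grad=False, local_rank=3) == 3  # ZeRO stage 1
```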
