
Commit 43995ee

Authored by Edenzzzz, chongqichuizi875, pre-commit-ci[bot], duanjunwen, and ver217
[Feature] Distributed optimizers: Lamb, Galore, CAME and Adafactor (#5694)
* [feat] Add distributed lamb; minor fixes in DeviceMesh (#5476)
  * init: add dist lamb; add debiasing for lamb
  * dist lamb tester mostly done
  * all tests passed
  * add comments
  * all tests passed; removed debugging statements
  * moved setup_distributed inside plugin; added dist layout caching
  * organize better
  Co-authored-by: Edenzzzz <[email protected]>

* [hotfix] Improve tester precision by removing ZeRO on vanilla lamb (#5576)
  Co-authored-by: Edenzzzz <[email protected]>

* [optim] add distributed came (#5526)
  * test CAME under LowLevelZeroOptimizer wrapper
  * test CAME TP row and col pass
  * test CAME zero pass
  * came zero add master and worker param id convert
  * came zero test pass
  * came zero test pass
  * test distributed came passed
  * reform code; modify some expressions and add comments
  * minor fix of test came
  * minor fix of dist_came and test
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
  * minor fix of dist_came and test
  * rebase dist-optim
  * rebase dist-optim
  * fix remaining comments
  * add test dist came using booster api
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [optim] Distributed Adafactor (#5484)
  * [feature] solve conflict; update optimizer readme
  * [feature] update optimizer readme
  * [fix] fix testcase
  * [feature] add transformer-bert to testcase; solve a bug related to indivisible shape (induction in use_zero and tp is row parallel)
  * [feature] add transformers_bert model zoo in testcase
  * [feature] add user documentation to docs/source/feature
  * [feature] add API Reference & Sample to optimizer Readme; add state check for bert exam
  * [feature] modify user documentation
  * [fix] fix readme format issue
  * [fix] add zero=0 in testcase; cached augment in dict
  * [fix] fix precision issue
  * [feature] add distributed rms
  * [feature] remove useless comment in testcase
  * [fix] remove useless test; open zero test; remove fp16 test in bert exam
  * [feature] extract distributed rms function
  * [feature] add booster + LowLevelZeroPlugin in test
  * [feature] add Start_with_booster_API case in md; add Supporting Information in md
  * [fix] also remove state movement in base adafactor
  * [feature] extract factor function
  * [feature] add LowLevelZeroPlugin test
  * [fix] add tp=False and zero=True in logic
  * [fix] fix use_zero logic
  * [feature] add row residue logic in column parallel factor
  * [feature] add check optim state func
  * [feature] remove duplicate logic
  * [feature] update optim state check func and precision test bug
  * [fix] update/fix optim state; precision issue still exists
  * [fix] add use_zero check in _rms; add plugin support info in Readme; add Dist Adafactor init info
  * [feature] removed print & comments in utils
  * [feature] update Readme
  * [feature] add LowLevelZeroPlugin test with Bert model zoo
  * [fix] fix logic in _rms
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
  * [fix] remove comments in testcase
  * [feature] add zh-Han Readme
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Refactor dist CAME; fix precision error; add low level zero test with bert model zoo (#5676)
  * [feature] daily update
  * [fix] fix dist came
  * [feature] refactor dist came; fix precision error; add low level zero test with bert model zoo
  * [fix] open rms; fix low level zero test; fix dist came test function name
  * [fix] remove redundant test
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Feature] Add Galore (Adam, Adafactor) and distributed GaloreAdamW8bit (#5570)
  * init: add dist lamb; add debiasing for lamb
  * dist lamb tester mostly done
  * all tests passed
  * add comments
  * all tests passed; removed debugging statements
  * moved setup_distributed inside plugin; added dist layout caching
  * organize better
  * update comments
  * add initial distributed galore
  * add initial distributed galore
  * add galore set param utils; change setup_distributed interface
  * projected grad precision passed
  * basic precision tests passed
  * tests passed; located svd precision issue in fwd-bwd; banned these tests
  * Plugin DP + TP tests passed
  * move get_shard_dim to d_tensor
  * add comments
  * remove useless files
  * remove useless files
  * fix zero typo
  * improve interface
  * remove moe changes
  * [pre-commit.ci] auto fixes from pre-commit.com hooks (see https://pre-commit.ci)
  * fix import
  * fix deepcopy
  * update came & adafactor to main
  * fix param map
  * fix typo
  Co-authored-by: Edenzzzz <[email protected]>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [Hotfix] Remove one buggy test case from dist_adafactor for now (#5692)
  Co-authored-by: Edenzzzz <[email protected]>
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Co-authored-by: Edenzzzz <[email protected]>
Co-authored-by: chongqichuizi875 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: duanjunwen <[email protected]>
Co-authored-by: Hongxin Liu <[email protected]>
1 parent 393c8f5 commit 43995ee

30 files changed: +4821 -42 lines
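Before the per-file diffs, a minimal usage sketch of the new optimizers through the booster API. This sketch is not part of the commit: the model, loss, parallel sizes, and precision are placeholders, and the launch call assumes a torchrun-style start (its exact signature varies across ColossalAI versions).

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistributedLamb

colossalai.launch_from_torch()  # assumes torchrun; older versions also take a config dict

model = nn.Linear(1024, 1024)   # placeholder model
criterion = nn.MSELoss()        # placeholder loss
optimizer = DistributedLamb(model.parameters(), lr=1e-3)

# tp_size * pp_size * dp_size must equal the world size; the values here are illustrative.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=0, precision="bf16")
booster = Booster(plugin=plugin)

# boost() hands the plugin's TP/DP process groups to the optimizer through the
# new DistributedOptim.setup_distributed() hook added in the diffs below.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)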

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 27 additions & 5 deletions
@@ -1,7 +1,9 @@
 import ctypes
 import random
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from types import MethodType
 from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -24,6 +26,8 @@
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import DistGaloreAwamW
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
@@ -1171,6 +1175,15 @@ def configure(
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
         param_info = get_param_info(optimizer)
+
+        # TODO: Support Galore + ZeRO
+        zero_stage = self.zero_stage
+        zero_config = deepcopy(self.zero_config)
+        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
+            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            zero_config["partition_grad"] = False
+            zero_stage = 0
+
         if not isinstance(model, ModelWrapper):
             use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
                 self.dp_size == 1
@@ -1194,7 +1207,8 @@ def configure(
                 custom_policy=self.custom_policy,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
-            if self.zero_stage == 0:
+            if zero_stage == 0:
+                is_zero = False
                 if self.precision in ["fp16", "bf16"]:
                     optimizer = HybridParallelAMPOptimizer(
                         optimizer,
@@ -1218,11 +1232,11 @@ def configure(
                         tp_process_group=self.tp_group,
                     )
             else:
-                zero_dp_size = dist.get_world_size(dp_group)
-                if zero_dp_size == 1:
+                is_zero = self.dp_size > 1
+                if self.dp_size == 1:
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
-                        "If you are not intended to use cpu_offload, please consider set zero_stage=0."
+                        "If you do not intend to use cpu_offload, please consider set zero_stage=0."
                     )
 
                 assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
@@ -1236,11 +1250,19 @@ def configure(
                     pp_process_group=self.pp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
-                    **self.zero_config,
+                    **zero_config,
                     **self.amp_config,
                 )
                 # inject update_master_params
                 model.update_master_params = MethodType(optimizer.update_master_params, model)
+
+            # Setup optimizers that require global states
+            optim = optimizer.optim
+            if isinstance(optim, DistributedOptim):
+                shard_to_param = optimizer.get_master_to_working_map() if is_zero else {}
+                padding_map = optimizer.get_param_padding_map() if is_zero else defaultdict(int)
+                optim.setup_distributed(self.tp_group, self.dp_group, shard_to_param, padding_map, is_zero)
+
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(
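The Galore guard above can be exercised from the user side with a sketch like the following (illustrative only; the GaLore group keys follow galore_torch conventions, and the constructor arguments are assumptions, not taken from this commit). Requesting zero_stage=1 with DistGaloreAwamW triggers the warning and runs the optimizer under plain TP + DP instead.

import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import DistGaloreAwamW

# torch.distributed is assumed to be initialized already (e.g. via colossalai.launch_from_torch).
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# GaLore projects 2D weights into a low-rank space; group keys as in galore_torch.
galore_params = [p for p in model.parameters() if p.dim() == 2]
param_groups = [
    {"params": [p for p in model.parameters() if p.dim() != 2]},
    {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
optimizer = DistGaloreAwamW(param_groups, lr=1e-3)

# zero_stage=1 is requested, but configure() detects DistGaloreAwamW, warns,
# and falls back to zero_stage=0 with partition_grad disabled.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)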

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 24 additions & 1 deletion
@@ -8,7 +8,10 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import torch
+import torch.distributed
+import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed.distributed_c10d import _get_default_group
 from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -28,6 +31,8 @@
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+from colossalai.interface.optimizer import DistributedOptim
+from colossalai.nn.optimizer import DistGaloreAwamW
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer
 
@@ -428,13 +433,31 @@ def configure(
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
 
+        # TODO: Support Galore + ZeRO
+        zero_stage = self.stage
+        zero_optim_kwargs = {**self.zero_optim_kwargs}
+        dp_size = dist.get_world_size()
+        if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
+            warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
+            zero_optim_kwargs["partition_grad"] = False
+            zero_stage = 0
+
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **self.zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
 
+            # Setup optimizers that require global states
+            optim = optimizer.optim
+            is_zero = dp_size > 1 and zero_stage > 0
+            dp_group = _get_default_group()  # Use the whole world
+            if isinstance(optim, DistributedOptim):
+                shard_to_param = optimizer.get_master_to_working_map()
+                padding_map = optimizer.get_param_padding_map()
+                optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero)
+
        return model, optimizer, criterion, dataloader, lr_scheduler
 
     def control_checkpoint_io(self) -> bool:
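A hypothetical end-to-end sketch of the LowLevelZero path above: there is no TP group, so setup_distributed() receives tp_group=None and the default (whole-world) process group serves as the ZeRO data-parallel group. Constructor arguments and hyperparameters are illustrative, not taken from this commit.

import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import DistributedCAME

colossalai.launch_from_torch()           # torchrun launch assumed; signature may vary by version

model = nn.Linear(1024, 1024)            # placeholder model
optimizer = DistributedCAME(model.parameters(), lr=2e-4)

plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
booster = Booster(plugin=plugin)

# boost() wraps the optimizer in LowLevelZeroOptimizer, then calls
# optim.setup_distributed(None, dp_group, shard_to_param, padding_map, is_zero)
# with the master-to-working parameter map built by ZeRO.
model, optimizer, *_ = booster.boost(model, optimizer)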

colossalai/cluster/process_group_mesh.py

Lines changed: 6 additions & 1 deletion
@@ -38,7 +38,12 @@ class ProcessGroupMesh:
 
     def __init__(self, *size: int) -> None:
         assert dist.is_initialized(), "Please initialize torch.distributed first."
-        assert prod(size) == dist.get_world_size(), "The product of the size must be equal to the world size."
+        world_size = dist.get_world_size()
+        prod_size = prod(size)
+        assert (
+            prod_size == world_size
+        ), f"The product of the size({prod_size}) must be equal to the world size({world_size})."
+
         self._shape = size
         self._rank = dist.get_rank()
         self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
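A small illustration of the tightened assertion (hypothetical values; assumes torch.distributed is already initialized across 8 ranks):

import torch.distributed as dist

from colossalai.cluster import ProcessGroupMesh

assert dist.is_initialized() and dist.get_world_size() == 8

mesh = ProcessGroupMesh(2, 4)   # 2 * 4 == 8, matches the world size
# ProcessGroupMesh(2, 3)        # now fails with the explicit message:
#                               # "The product of the size(6) must be equal to the world size(8)."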

colossalai/device/device_mesh.py

Lines changed: 3 additions & 3 deletions
@@ -306,9 +306,8 @@ def _init_global_to_logical_rank_mapping(
             # index means the local rank in the current axis
             # inner_tensor refers to the processes with the same local rank
 
-            if inner_tensor.numel() == 1:
-                # if the inner_tensor only has one element, it means that
-                # it already reaches the last axis
+            if inner_tensor.dim() == 0:
+                # if the inner_tensor already reaches the last axis,
                 # we append its local_rank in the last axis to the index_list
                 # and assign to the mapping
                 # the value of the mapping is the the local rank at the indexed axis of the device mesh
@@ -459,6 +458,7 @@ def _collate_global_ranks_in_same_process_group(self, global_rank):
 
                 # replace the local rank in the given dimension with the
                 # local rank of the current process iterated
+
                 process_coordinates[dim] = _local_rank
                 processes_in_the_same_process_group[dim].append(process_coordinates)
 
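The switch from numel() == 1 to dim() == 0 presumably matters when a mesh axis has size 1: a one-element tensor can still have axes left to descend into, so the old check would stop a level too early. A small illustration with made-up values:

import torch

inner = torch.arange(1).reshape(1, 1)  # a mesh slice of shape (1, 1), e.g. from axes of size 1
print(inner.numel() == 1)              # True  -> the old condition would stop recursing here
print(inner.dim() == 0)                # False -> the new condition keeps descending

leaf = inner[0][0]                     # indexing down to a single rank
print(leaf.dim() == 0)                 # True  -> recursion now terminates on a 0-d tensor
print(leaf.item())                     # 0: the global rank stored at this coordinate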

colossalai/interface/optimizer.py

Lines changed: 24 additions & 1 deletion
@@ -1,6 +1,7 @@
-from typing import Union
+from typing import Dict, Optional, Union
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
@@ -133,3 +134,25 @@ def unwrap(self):
         Unwrap the optimizer for checkpoint saving/loading.
         """
         return self.optim
+
+
+class DistributedOptim(Optimizer):
+    def setup_distributed(
+        self,
+        tp_group: Optional[dist.ProcessGroup] = None,
+        dp_group: Optional[dist.ProcessGroup] = None,
+        shard_to_working_param: Optional[Dict] = {},
+        padding_map: Optional[Dict] = None,
+        is_zero: Optional[bool] = False,
+    ):
+        """Assign process groups for TP and ZeRO 2.
+        Arguments:
+            tp_group (dist.ProcessGroup): Tensor Parallel process group
+            dp_group (dist.ProcessGroup): ZeRO stage 2 process group
+            shard_to_working_param (Dict): ZeRO stage 2 feeds the optimizer a sharded param view to match grad shape.
+                This maps from id(view) to model params used in forward & backward.
+            padding_map (Dict): Per-param padding from ZeRO stage 2
+            is_zero (bool): Whether to use ZeRO stage 2.
+        """
+
+        raise NotImplementedError("setup_distributed for TP/DP isn't supported by this optimizer yet!")
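A minimal sketch (not part of this commit) of implementing the new DistributedOptim interface. The class and attribute names below are illustrative; the real implementations (DistributedLamb, DistributedCAME, DistributedAdaFactor, DistGaloreAwamW) additionally reshard their optimizer states along tp_group and dp_group.

from typing import Dict, Optional

import torch.distributed as dist
from torch.optim import SGD

from colossalai.interface.optimizer import DistributedOptim


class MyDistributedSGD(DistributedOptim, SGD):
    def setup_distributed(
        self,
        tp_group: Optional[dist.ProcessGroup] = None,
        dp_group: Optional[dist.ProcessGroup] = None,
        shard_to_working_param: Optional[Dict] = {},
        padding_map: Optional[Dict] = None,
        is_zero: Optional[bool] = False,
    ):
        # Record the process groups handed over by the plugin in configure().
        self.tp_group = tp_group
        self.dp_group = dp_group
        # Map each ZeRO-2 master shard back to its working parameter.
        self.shard_to_working_param = shard_to_working_param or {}
        self.padding_map = padding_map or {}
        self.is_zero = is_zero

# e.g. opt = MyDistributedSGD(model.parameters(), lr=0.1); the plugin's configure()
# then calls opt.setup_distributed(...) exactly as the hunks above do.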
colossalai/nn/optimizer/__init__.py

Lines changed: 28 additions & 1 deletion
@@ -1,9 +1,36 @@
+from galore_torch import GaLoreAdafactor, GaLoreAdamW
+
+from .came import CAME
 from .cpu_adam import CPUAdam
+from .distributed_adafactor import DistributedAdaFactor
+from .distributed_came import DistributedCAME
+from .distributed_galore import DistGaloreAwamW
+from .distributed_lamb import DistributedLamb
 from .fused_adam import FusedAdam
 from .fused_lamb import FusedLAMB
 from .fused_sgd import FusedSGD
+from .galore import GaLoreAdamW8bit
 from .hybrid_adam import HybridAdam
 from .lamb import Lamb
 from .lars import Lars
 
-__all__ = ["FusedLAMB", "FusedAdam", "FusedSGD", "Lamb", "Lars", "CPUAdam", "HybridAdam"]
+from .adafactor import Adafactor  # noqa
+
+__all__ = [
+    "FusedLAMB",
+    "FusedAdam",
+    "FusedSGD",
+    "Lamb",
+    "Lars",
+    "CPUAdam",
+    "HybridAdam",
+    "DistributedLamb",
+    "DistGaloreAwamW",
+    "GaLoreAdamW",
+    "GaLoreAdafactor",
+    "GaLoreAdamW8bit",
+    "CAME",
+    "DistributedCAME",
+    "Adafactor",
+    "DistributedAdaFactor",
+]
