Commit 5f8c0a0

[Feature] auto-cast optimizers to distributed version (#5746)
* auto-cast optimizers to distributed
* fix galore casting
* logger

Co-authored-by: Edenzzzz <[email protected]>
1 parent 2fc85ab commit 5f8c0a0

File tree

13 files changed: +61 additions, -31 deletions


colossalai/booster/plugin/hybrid_parallel_plugin.py
Lines changed: 5 additions & 1 deletion

@@ -27,7 +27,7 @@
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer

@@ -1179,6 +1179,10 @@ def configure(
         # TODO: Support Galore + ZeRO
         zero_stage = self.zero_stage
         zero_config = deepcopy(self.zero_config)
+
+        # Replace with distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_config["partition_grad"] = False

colossalai/booster/plugin/low_level_zero_plugin.py
Lines changed: 5 additions & 1 deletion

@@ -32,7 +32,7 @@
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
-from colossalai.nn.optimizer import DistGaloreAwamW
+from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
 from colossalai.zero import LowLevelZeroOptimizer

@@ -437,6 +437,10 @@ def configure(
         zero_stage = self.stage
         zero_optim_kwargs = {**self.zero_optim_kwargs}
         dp_size = dist.get_world_size()
+
+        # Replace with the distributed implementation if exists
+        optimizer = cast_to_distributed(optimizer)
+
         if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
             warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
             zero_optim_kwargs["partition_grad"] = False

colossalai/nn/optimizer/__init__.py
Lines changed: 21 additions & 0 deletions

@@ -1,5 +1,7 @@
 from galore_torch import GaLoreAdafactor, GaLoreAdamW
 
+from colossalai.logging import get_dist_logger
+
 from .came import CAME
 from .cpu_adam import CPUAdam
 from .distributed_adafactor import DistributedAdaFactor

@@ -34,3 +36,22 @@
     "Adafactor",
     "DistributedAdaFactor",
 ]
+
+optim2DistOptim = {
+    GaLoreAdamW8bit: DistGaloreAwamW,
+    Lamb: DistributedLamb,
+    CAME: DistributedCAME,
+    Adafactor: DistributedAdaFactor,
+}
+_logger = get_dist_logger()
+
+
+def cast_to_distributed(optim):
+    if optim.__class__ in optim2DistOptim:
+        _logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])
+
+        if isinstance(optim, GaLoreAdamW8bit):
+            return optim2DistOptim[GaLoreAdamW8bit](optim.param_groups, args=optim.args)
+        return optim2DistOptim[optim.__class__](optim.param_groups)
+
+    return optim
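
To make the mapping concrete, a small sketch of calling the helper directly; it assumes the process group has already been launched (for example via `colossal run`) as in the docs' hands-on example.

```python
# Hedged sketch: cast_to_distributed swaps known optimizer classes for their distributed
# counterparts via optim2DistOptim and returns anything else unchanged.
import torch
import torch.nn as nn

import colossalai
from colossalai.nn.optimizer import Adafactor, DistributedAdaFactor, cast_to_distributed

colossalai.launch_from_torch()
model = nn.Linear(64, 64).cuda()

optim = cast_to_distributed(Adafactor(model.parameters()))
assert isinstance(optim, DistributedAdaFactor)   # mapped to the distributed counterpart

sgd = torch.optim.SGD(model.parameters(), lr=0.1)
assert cast_to_distributed(sgd) is sgd           # optimizers without a mapping pass through
```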

colossalai/nn/optimizer/distributed_came.py
Lines changed: 0 additions & 3 deletions

@@ -34,9 +34,6 @@ def __init__(
         betas=(0.9, 0.999, 0.9999),
         weight_decay=0.0,
     ):
-        assert lr > 0.0
-        assert all([0.0 <= beta <= 1.0 for beta in betas])
-
         defaults = dict(
             lr=lr,
             eps=eps,

colossalai/nn/optimizer/distributed_galore.py
Lines changed: 9 additions & 6 deletions

@@ -43,12 +43,13 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
     """

     def __init__(
         self,
         params,
-        lr=1e-3,
+        lr=1e-2,
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=1e-2,

@@ -57,6 +58,7 @@ def __init__(
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",

@@ -65,13 +67,14 @@ def __init__(
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )
+
         self.tp_size = 1
         self.dp_size = 1
         self.is_dist = {}
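
A hedged sketch (my own, not part of the commit) of what the new `args` plumbing enables: the GaLore branch of `cast_to_distributed` rebuilds the optimizer from the original param groups and forwards its bitsandbytes `args`, which now take precedence over the individual quantization settings. It assumes bitsandbytes and galore_torch are installed and a ColossalAI process has been launched; the linear model and hyperparameters are illustrative.

```python
# Hedged sketch: an 8-bit GaLore optimizer is recreated as DistGaloreAwamW with the same
# param groups and quantization args, i.e. DistGaloreAwamW(optim.param_groups, args=optim.args).
import torch.nn as nn

import colossalai
from colossalai.nn.optimizer import DistGaloreAwamW, GaLoreAdamW8bit, cast_to_distributed

colossalai.launch_from_torch()
model = nn.Linear(128, 128).cuda()

optim = GaLoreAdamW8bit(model.parameters(), lr=1e-2, nbits=8)
dist_optim = cast_to_distributed(optim)
assert isinstance(dist_optim, DistGaloreAwamW)
```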

colossalai/nn/optimizer/galore.py
Lines changed: 7 additions & 5 deletions

@@ -184,6 +184,7 @@ class GaLoreAdamW8bit(Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
     Example:

     """

@@ -200,6 +201,7 @@ def __init__(
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",

@@ -208,11 +210,11 @@
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )

docs/source/en/features/distributed_optimizers.md
Lines changed: 3 additions & 2 deletions

@@ -9,7 +9,8 @@ Author: [Wenxuan Tan](https://github.com/Edenzzzz), [Junwen Duan](https://github
 - [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)
 
 ## Introduction
-Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to efficiently update parameters, and are thus not directly applicable to parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO using plugins.
+Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters, and thus aren't directly applicable to settings where model layers are sharded across multiple devices. We provide optimized distributed implementations with minimal extra communications, and seamless integrations with Tensor Parallel, DDP and ZeRO plugins, which automatically uses distributed optimizers with 0 code change.
+
 ## Optimizers
 Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(NMF) to reduce memory footprint. CAME improves by introducting a confidence matrix to correct NMF. GaLore further reduces memory by projecting gradients into a low-rank space and 8-bit block-wise quantization. Lamb allows huge batch sizes without lossing accuracy via layer-wise adaptive update bounded by the inverse of its Lipschiz constant.

@@ -21,7 +22,7 @@ Adafactor is a first-order Adam variant using Non-negative Matrix Factorization(
 {{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
 
 ## Hands-On Practice
-We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
+We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs. **Note that even if you're not aware of distributed optimizers, the plugins automatically casts yours to the distributed version for convenience.**
 ### step 1. Import libraries
 
 ```python
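
To illustrate the sentence added to the Hands-On section, a hedged sketch of the scenario the doc describes (Tensor Parallel plus ZeRO-2 on 4 GPUs, launched with `colossal run --nproc_per_node 4`); the exact plugin configuration is my assumption, not copied from the doc.

```python
# Hedged sketch: a plain Adafactor is handed to a TP + ZeRO-2 plugin and is expected to be
# cast to DistributedAdaFactor automatically by the booster.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.nn.optimizer import Adafactor
from transformers import LlamaConfig, LlamaModel

colossalai.launch_from_torch()
model = LlamaModel(LlamaConfig()).cuda()
optimizer = Adafactor(model.parameters())

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=2)  # illustrative config
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```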

docs/source/zh-Hans/features/distributed_optimizers.md
Lines changed: 3 additions & 5 deletions

@@ -9,7 +9,7 @@ Author: Wenxuan Tan, Junwen Duan, Renjie Mao
 - [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes] (https://arxiv.org/pdf/1904.00962)
 
 ## 介绍
-除了广泛采用的Adam和SGD外,许多现代优化器需要逐层统计信息以有效更新参数,因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现,,并且通过插件与Tensor Parallel、DDP和ZeRO无缝集成。
+除了广泛采用的Adam和SGD外,许多现代优化器需要逐层统计信息以有效更新参数,因此无法直接应用于模型层在多个设备上分片的并行设置。我们以提供了优化的分布式实现,,并且通过plugin与Tensor Parallel、DDP和ZeRO无缝集成。
 ## 优化器
 Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用于减少内存占用。CAME 通过引入一个置信度矩阵来改进 NMF 的效果。GaLore 通过将梯度投影到低秩空间,并使用 8 位块状量化进一步减少内存占用。Lamb 允许使用巨大的批量大小而不失准确性,通过按其 Lipschitz 常数的倒数界定的逐层自适应更新实现

@@ -21,7 +21,7 @@ Adafactor 是一种首次采用非负矩阵分解(NMF)的 Adam 变体,用
 {{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
 
 ## 使用
-We now demonstrate how to use Distributed Adafactor with booster API combining Tensor Parallel and ZeRO 2 with 4 GPUs.
+现在我们展示如何使用分布式 Adafactor booster API 结合 Tensor Parallel ZeRO 2。即使您不使用distributed optimizer,plugin 也会自动将optimizer转换为分布式版本以方便使用。
 ### step 1. 导包
 
 ```python

@@ -34,15 +34,13 @@ import torch
 ```
 
 ### step 2. 初始化分布式
-We need to initialize distributed environment. For demo purpose, we use `colossal run --nproc_per_node 4`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
+我们需要先初始化分布式环境. 为了展示, 我们使用 `colossal run --nproc_per_node 4`. 更多初始化方式请参考 [Launch Colossal-AI](../basics/launch_colossalai.md)
 
 ```python
 colossalai.launch_from_torch()
 ```
 
 ### step 3. 初始化模型和优化器
-Build our model. We created an MLP using two Linear Layer.
-
 ```python
 configuration = LlamaConfig()
 model = LlamaModel(configuration).cuda()

tests/test_optimizer/test_dist_adafactor.py
Lines changed: 1 addition & 1 deletion

@@ -552,7 +552,7 @@ def exam_bert_test_on_lowlevelzero_plugin(test_config):
         sharded_optimizer,
         criterion,
         booster,
-    ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, DistributedAdaFactor)
+    ) = build_model_from_low_level_zero_plugin(model_fn, loss_fn, test_config, Adafactor, Adafactor)
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_low_level_zero_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster

tests/test_optimizer/test_dist_came.py
Lines changed: 1 addition & 1 deletion

@@ -416,7 +416,7 @@ def exam_bert_test_on_hybrid_plugin(test_config):
         sharded_optimizer,
         criterion,
         booster,
-    ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, DistributedCAME)
+    ) = build_model_from_hybrid_plugin(model_fn, loss_fn, test_config, CAME, CAME)
 
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
