
Commit a7d95b7

[example] add zero1, zero2 example in GPT examples (#2146)
* [example] add zero1 and zero2 for GPT
* update readme in gpt example
* polish code
* change init value
* update readme
1 parent 1cce6e3 commit a7d95b7

5 files changed: +39 lines, -26 lines
colossalai/zero/sharded_optim/low_level_optim.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -35,13 +35,13 @@ def __init__(
             optimizer: Optimizer,

             # grad scaler config
-            initial_scale=2**32,
+            initial_scale=2**16,
             min_scale=1,
             growth_factor=2,
             backoff_factor=0.5,
-            growth_interval=1000,
+            growth_interval=2000,
             hysteresis=2,
-            max_scale: int = 2**32,
+            max_scale: int = 2**24,

             # grad clipping
             clip_grad_norm=0.0,
```
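
The hunk above changes the default loss-scaling behaviour of `LowLevelZeroOptimizer` (smaller `initial_scale`/`max_scale`, longer `growth_interval`). The sketch below is a minimal, illustrative way to construct the wrapper with those values spelled out explicitly; it assumes a distributed environment has already been launched (as `train_gpt_demo.py` does via Colossal-AI), and the toy model and Adam settings are placeholders, not part of the commit.

```python
# A minimal sketch, not the demo's code: only the keyword names come from this commit.
import torch
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


def build_zero_optimizer(model: torch.nn.Module, use_zero2: bool = False):
    # Plain Adam as the inner optimizer; the learning rate is a placeholder.
    base_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    # New defaults from this commit: initial_scale=2**16, growth_interval=2000,
    # max_scale=2**24; passing them explicitly makes them easy to tune per workload.
    return LowLevelZeroOptimizer(base_optim,
                                 initial_scale=2**16,
                                 growth_interval=2000,
                                 max_scale=2**24,
                                 overlap_communication=True,
                                 partition_grad=use_zero2,   # False -> ZeRO-1, True -> ZeRO-2
                                 verbose=True)
```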

examples/language/gpt/README.md

Lines changed: 8 additions & 5 deletions

````diff
@@ -19,10 +19,10 @@ conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit
 pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
 ```
 
-### Install [Colossal-AI v0.1.11rc5](https://colossalai.org/download/) From Official Website
+### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Official Website
 
 ```bash
-pip install colossalai==0.1.11rc5+torch1.12cu11.3 -f https://release.colossalai.org
+pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
 ```
 
 ### Install transformers
@@ -31,7 +31,8 @@ pip install colossalai==0.1.11rc5+torch1.12cu11.3 -f https://release.colossalai.
 pip install transformers
 ```
 
-This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.11rc5+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231.
+This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.12+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231.
+If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-AI>=0.1.12.
 
 ## Dataset
 
@@ -48,5 +49,7 @@ bash run.sh
 The `train_gpt_demo.py` provides three distributed plans, you can choose the plan you want in `run.sh`. The Colossal-AI leverages Tensor Parallel and Gemini + ZeRO DDP.
 
 - Colossal-AI
-- PyTorch DDP
-- ZeRO
+- ZeRO1 (Colossal-AI)
+- ZeRO2 (Colossal-AI)
+- Pytorch DDP
+- Pytorch ZeRO
````
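
The new README note requires Colossal-AI >= 0.1.12 for the ZeRO1/ZeRO2 plans. Below is a small, hedged sketch of a guard one could add before selecting those plans; it uses the standard-library `importlib.metadata` together with `packaging.version` (which the demo script already imports), and the error message itself is illustrative, not from the commit.

```python
# A minimal sketch of the version guard implied by the README note above.
from importlib.metadata import version as pkg_version
from packaging import version

# Raise early if the installed Colossal-AI is too old for the zero1/zero2 plans.
if version.parse(pkg_version("colossalai")) < version.parse("0.1.12"):
    raise RuntimeError("ZeRO1/ZeRO2 plans need Colossal-AI >= 0.1.12; "
                       "see the install section above.")
```
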
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,3 +1,3 @@
-colossalai >= 0.1.10
+colossalai >= 0.1.12
 torch >= 1.8.1
 transformers >= 4.231
```

examples/language/gpt/run.sh

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,4 +1,4 @@
-# distplan in ["colossalai", "zero", "ddp"]
+# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
 export DISTPAN="colossalai"
 
 # The following options only valid when DISTPAN="colossalai"
```

examples/language/gpt/train_gpt_demo.py

Lines changed: 26 additions & 16 deletions

```diff
@@ -6,6 +6,7 @@
 import torch.nn as nn
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers import GPT2Config, GPT2LMHeadModel
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -16,7 +17,7 @@
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from transformers import GPT2Config, GPT2LMHeadModel
+from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
 
 
 def parse_args():
@@ -25,7 +26,7 @@ def parse_args():
         "--distplan",
         type=str,
         default='colossalai',
-        help="The distributed plan [colossalai, ddp, zero].",
+        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
     parser.add_argument(
         "--tp_degree",
@@ -202,6 +203,9 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
 def main():
     args = parse_args()
 
+    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+        raise TypeError(f"{args.distplan} is error")
+
     BATCH_SIZE = 8
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
@@ -237,19 +241,24 @@ def main():
         # optimizer = HybridAdam(model.parameters(), lr=1e-3)
         # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
-
-    elif args.distplan == "ddp":
+    else:
         model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
 
-    elif args.distplan == "zero":
-        from torch.distributed.optim import ZeroRedundancyOptimizer
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
-    else:
-        raise TypeError(f"{args.distplan} is error")
+        if args.distplan.startswith("torch"):
+            model = DDP(model)
+            if args.distplan.endswith("ddp"):
+                optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+            elif args.distplan.endswith("zero"):
+                from torch.distributed.optim import ZeroRedundancyOptimizer
+                optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
+        elif args.distplan.startswith("zero"):
+            partition_flag = args.distplan == "zero2"
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+            optimizer = LowLevelZeroOptimizer(optimizer,
+                                              overlap_communication=True,
+                                              partition_grad=partition_flag,
+                                              verbose=True)
+        # notice that the model is still in fp32
 
     numel = sum([p.numel() for p in model.parameters()])
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
@@ -265,12 +274,13 @@ def main():
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
-        if args.distplan == "colossalai":
+        if args.distplan in ["colossalai", "zero1", "zero2"]:
             optimizer.backward(loss)
-        elif args.distplan in ["ddp", "zero"]:
+        elif args.distplan in ["torch_ddp", "torch_zero"]:
             loss.backward()
-
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
+        if args.distplan in ["zero1", "zero2"]:
+            optimizer.sync_grad()
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         step_time = time() - start
```
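
For the new `zero1`/`zero2` plans, the training step goes through the wrapper: `optimizer.backward(loss)` for the scaled backward pass and `optimizer.sync_grad()` before `optimizer.step()`, as the last hunk adds. Below is a condensed, illustrative sketch of that per-step logic; `model`, `criterion` and the batch tensors stand in for whatever `train_gpt_demo.py` builds, and the function itself is not part of the commit.

```python
# Condensed sketch of the per-step branch added above; `optimizer` is either the
# LowLevelZeroOptimizer wrapper (colossalai/zero1/zero2) or a plain torch optimizer.
def train_step(model, criterion, optimizer, input_ids, attn_mask, distplan="zero2"):
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)

    if distplan in ["colossalai", "zero1", "zero2"]:
        optimizer.backward(loss)   # backward through the wrapper (handles loss scaling)
    else:                          # torch_ddp / torch_zero
        loss.backward()

    if distplan in ["zero1", "zero2"]:
        optimizer.sync_grad()      # synchronize sharded gradients before stepping
    optimizer.step()
    return loss
```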
