
Commit 8c9717a

[AutoParallel] add sharding opt config (#6124)
1 parent ecb2f66 commit 8c9717a

5 files changed: +79, -3 lines changed
Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+_base_: ./pretrain_gpt_base.yaml
+
+Global:
+  global_batch_size:
+  local_batch_size: 1
+  micro_batch_size: 1
+
+
+Model:
+  vocab_size: 50304
+  hidden_size: 5120
+  num_layers: 40
+  num_attention_heads: 40
+  ffn_hidden_size:
+  hidden_dropout_prob: 0.1
+  attention_probs_dropout_prob: 0.1
+  max_position_embeddings: 1024
+  type_vocab_size: 16
+  initializer_range: 0.02
+  fuse_attn_qkv: True
+  use_recompute: True
+  recompute_granularity:
+  no_recompute_layers:
+
+
+Distributed:
+  dp_degree:
+  mp_degree: 1
+  pp_degree: 1
+  sharding:
+    sharding_degree: 8
+    sharding_stage: 3
+    reduce_overlap: True
+    broadcast_overlap: True
+    param_comm_stream_num: 3
+    grad_comm_stream_num: 3
+    param_bucket_size_numel: 210355872
+    grad_bucket_size_numel: 210355872
+    enable_hierarchical_comm: False
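For scale, a rough back-of-the-envelope parameter count for the Model section above (a sketch under the usual GPT assumptions, not code from the commit; it ignores biases and layernorm and assumes ffn_hidden_size defaults to 4 * hidden_size when left unset):

    # approximate parameter count implied by the config above
    hidden_size = 5120
    num_layers = 40
    vocab_size = 50304
    max_position_embeddings = 1024

    embeddings = (vocab_size + max_position_embeddings) * hidden_size
    per_layer = 12 * hidden_size ** 2   # ~4*h^2 for attention + ~8*h^2 for the MLP
    total = embeddings + num_layers * per_layer

    print(f"~{total / 1e9:.1f}B parameters")   # prints ~12.8B, i.e. the "13B" scale

At that size the config uses sharding_stage: 3, which shards parameters as well as gradients and optimizer state across the 8 devices, and turns on both reduce_overlap and broadcast_overlap so the sharding communication can overlap with compute.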

model_zoo/gpt-3/ppfleetx/models/language_model/gpt/auto/auto_model.py

Lines changed: 5 additions & 2 deletions
@@ -93,7 +93,7 @@ def __init__(
         self.ipp = ipp

         self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim[{}] must be divisible by num_heads[{}]".format(self.embed_dim, num_heads)

         if self.fuse_attn_qkv:
             assert self.kdim == embed_dim
@@ -290,7 +290,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, use_cache=False,
         new_caches = []

         for i, mod in enumerate(self.layers):
-            mod = auto.shard_op(mod, auto_env.get_mesh()[mod.ipp])
+            ipp = mod.ipp
+            mod = auto.shard_op(mod, auto_env.get_mesh()[ipp])

             if cache is None:
                 if use_cache:
@@ -305,6 +306,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, use_cache=False,
                 output, new_cache = mod(output, memory, tgt_mask=tgt_mask, use_cache=use_cache, cache=cache[i])
                 new_caches.append(new_cache)

+        auto.shard_tensor(output, auto_env.get_mesh()[ipp], [auto_env.get_mesh().dp_dim, None, None])
+
         if self.norm is not None:
             output = self.norm(output)
         return output if use_cache is False else (output, new_caches)
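The change keeps the last layer's pipeline-stage index (ipp) available outside the loop so the decoder output can be re-annotated on that stage's mesh, split along the data-parallel axis. A minimal, self-contained sketch of the pattern (the stand-in classes below are illustrative only; the real calls are paddle's auto.shard_op / auto.shard_tensor and ppfleetx's auto_env.get_mesh()):

    # Illustrative stand-ins; only the control flow mirrors the diff above.
    class Mesh:
        dp_dim = "dp"                          # data-parallel mesh axis
        def __getitem__(self, ipp):            # sub-mesh of pipeline stage `ipp`
            return f"stage-{ipp} mesh"

    def shard_op(layer, mesh):                 # stand-in for auto.shard_op
        print(f"run {layer.name} on {mesh}")
        return layer

    def shard_tensor(tensor, mesh, dims):      # stand-in for auto.shard_tensor
        print(f"shard {tensor} on {mesh} along {dims}")

    class Block:
        def __init__(self, name, ipp):
            self.name, self.ipp = name, ipp
        def __call__(self, x):
            return x

    mesh = Mesh()
    layers = [Block("layer0", 0), Block("layer1", 1)]
    output = "hidden_states"

    for mod in layers:
        ipp = mod.ipp                          # remember the stage index...
        mod = shard_op(mod, mesh[ipp])
        output = mod(output)

    # ...so it is still usable after the loop: the output is annotated on the
    # last stage's mesh, with the batch dim split along the data-parallel axis.
    shard_tensor(output, mesh[ipp], [mesh.dp_dim, None, None])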

model_zoo/gpt-3/ppfleetx/utils/auto_config.py

Lines changed: 8 additions & 0 deletions
@@ -46,6 +46,8 @@ def process_dist_configs(config):
     sharding_config = configs["sharding"]
     sharding_degree = sharding_config.setdefault("sharding_degree", 1)
     sharding_config.setdefault("sharding_stage", 2)
+    sharding_config.setdefault("reduce_overlap", False)
+    sharding_config.setdefault("broadcast_overlap", False)

     other_degree = mp_degree * pp_degree
@@ -184,6 +186,12 @@ def process_strategy(config):
     sharding.enable = sharding_cfg.get("sharding_degree", 1) > 1
     sharding.degree = sharding_cfg.get("sharding_degree", 1)
     sharding.stage = sharding_cfg.get("sharding_stage", 1)
+    sharding.enable_overlap = sharding_cfg.get("reduce_overlap", False) and sharding_cfg.get("broadcast_overlap", False)
+    sharding.param_comm_stream_num = sharding_cfg.get("param_comm_stream_num", 1)
+    sharding.grad_comm_stream_num = sharding_cfg.get("grad_comm_stream_num", 1)
+    sharding.param_bucket_size_numel = sharding_cfg.get("param_bucket_size_numel", 1)
+    sharding.grad_bucket_size_numel = sharding_cfg.get("grad_bucket_size_numel", 1)
+    sharding.enable_hierarchical_comm = sharding_cfg.get("enable_hierarchical_comm", False)

     pp_degree = config["Distributed"]["pp_degree"]
     accumulate_steps = config.Engine.get("accumulate_steps", 1)
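As a worked example (a sketch, not code from the commit): fed the sharding block of the new YAML config above, process_strategy populates the auto-parallel strategy roughly as below, with overlap enabled only because both reduce_overlap and broadcast_overlap are True:

    # hypothetical stand-alone rendering of the mapping; SimpleNamespace stands
    # in for the real strategy object used by process_strategy
    from types import SimpleNamespace

    sharding_cfg = {
        "sharding_degree": 8, "sharding_stage": 3,
        "reduce_overlap": True, "broadcast_overlap": True,
        "param_comm_stream_num": 3, "grad_comm_stream_num": 3,
        "param_bucket_size_numel": 210355872, "grad_bucket_size_numel": 210355872,
        "enable_hierarchical_comm": False,
    }

    sharding = SimpleNamespace()
    sharding.enable = sharding_cfg.get("sharding_degree", 1) > 1        # True
    sharding.degree = sharding_cfg.get("sharding_degree", 1)            # 8
    sharding.stage = sharding_cfg.get("sharding_stage", 1)              # 3
    sharding.enable_overlap = (sharding_cfg.get("reduce_overlap", False)
                               and sharding_cfg.get("broadcast_overlap", False))  # True
    sharding.param_bucket_size_numel = sharding_cfg.get("param_bucket_size_numel", 1)  # 210355872
    sharding.grad_bucket_size_numel = sharding_cfg.get("grad_bucket_size_numel", 1)    # 210355872
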
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#! /bin/bash
+# Runs the "1.3B" parameter model
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+log_dir=log_auto
+rm -rf $log_dir
+
+# 10B+sharding8 run_pretrain
+python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
+    ./tools/auto.py \
+    -c ./ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_13B_sharding8.yaml \
+    -o Engine.max_steps=1000 \
+    -o Engine.logging_freq=1 \
+    -o Engine.verbose=3

model_zoo/gpt-3/projects/gpt/auto_gpt_6.7B_sharding16.sh

Lines changed: 1 addition & 1 deletion
@@ -20,4 +20,4 @@ rm -rf $log_dir
 # 6.7B+sharding16 run_pretrain
 python -m paddle.distributed.launch --log_dir $log_dir --devices "0,1,2,3,4,5,6,7" \
     ./tools/auto.py \
-    -c ./ppfleetx/configs/nlp/gp/auto/pretrain_gpt_6.7B_sharding16.yaml
+    -c ./ppfleetx/configs/nlp/gpt/auto/pretrain_gpt_6.7B_sharding16.yaml
