
Commit 75735a8

[MoE] Add MoE model pretrain and support MoE+DP (#2385)
* support moe+dp
1 parent f982df8 commit 75735a8

File tree

14 files changed: +3450 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
../../../model_zoo/ernie-1.0/data_tools
Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import paddle
from paddlenlp.utils.log import logger


def process_batch_size(args):
    if args.global_batch_size is None and args.local_batch_size is None:
        raise ValueError("global_batch_size or local_batch_size should be set.")
    elif args.global_batch_size is not None and args.local_batch_size is not None:
        assert args.global_batch_size // args.local_batch_size == args.dp_degree, \
            "global_batch_size[{}] divided by local_batch_size[{}] should equal dp_degree[{}]"\
            .format(args.global_batch_size, args.local_batch_size, args.dp_degree)
    elif args.global_batch_size is not None and args.local_batch_size is None:
        args.local_batch_size = args.global_batch_size // args.dp_degree
    else:
        args.global_batch_size = args.local_batch_size * args.dp_degree
    assert args.local_batch_size % args.micro_batch_size == 0

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Unsupported value encountered.')

def parse_args(MODEL_CLASSES):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MODEL_CLASSES.keys()), )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(
            sum([
                list(classes[-1].pretrained_init_configuration.keys())
                for classes in MODEL_CLASSES.values()
            ], [])), )

    # Train I/O config
    parser.add_argument(
        "--input_dir",
        default=None,
        type=str,
        required=True,
        help="The input directory where the data will be read from.", )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the training logs and checkpoints will be written."
    )
    parser.add_argument(
        "--split",
        type=str,
        default='949,50,1',
        help="Train/valid/test data split.")

    parser.add_argument(
        "--max_seq_len", type=int, default=1024, help="Max sequence length.")

    parser.add_argument(
        "--global_batch_size",
        default=None,
        type=int,
        help="Global batch size for all training processes. None means the size is not checked. If only data parallelism is used, it should be device_num * micro_batch_size."
    )

    parser.add_argument(
        "--local_batch_size",
        default=None,
        type=int,
        help="Batch size per data-parallel rank. None means it is derived from global_batch_size and dp_degree."
    )

    parser.add_argument(
        "--micro_batch_size",
        default=8,
        type=int,
        help="Batch size per device for one step training.", )

    # Default training config
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        help="Weight decay if we apply some.")
    parser.add_argument(
        "--grad_clip",
        default=0.0,
        type=float,
        help="Grad clip for the parameter.")
    parser.add_argument(
        "--max_lr",
        default=0.00015,
        type=float,
        help="The initial max learning rate for Adam.")
    parser.add_argument(
        "--min_lr",
        default=1e-5,
        type=float,
        help="The initial min learning rate for Adam.")
    parser.add_argument(
        "--warmup_rate",
        default=0.01,
        type=float,
        help="Linear warmup over warmup_steps for learning rate.")

    # Adam optimizer config
    parser.add_argument(
        "--adam_beta1",
        default=0.9,
        type=float,
        help="The beta1 for Adam optimizer. The exponential decay rate for the 1st moment estimates."
    )
    parser.add_argument(
        "--adam_beta2",
        default=0.999,
        type=float,
        help="The beta2 for Adam optimizer. The exponential decay rate for the 2nd moment estimates."
    )
    parser.add_argument(
        "--adam_epsilon",
        default=1e-8,
        type=float,
        help="Epsilon for Adam optimizer.")

    # Training steps config
    parser.add_argument(
        "--num_train_epochs",
        default=1,
        type=int,
        help="Total number of training epochs to perform.", )
    parser.add_argument(
        "--max_steps",
        default=500000,
        type=int,
        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs."
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X update steps.")
    parser.add_argument(
        "--decay_steps",
        default=360000,
        type=int,
        help="The number of steps used to control the learning rate. If the step > decay_steps, min_lr will be used."
    )
    parser.add_argument(
        "--logging_freq",
        type=int,
        default=1,
        help="Log every X update steps.")
    parser.add_argument(
        "--eval_freq",
        type=int,
        default=500,
        help="Evaluate every X update steps.")
    parser.add_argument(
        "--eval_iters",
        type=int,
        default=10,
        help="Evaluate the model using X steps of data.")

    # Config for 4D Parallelism

    parser.add_argument(
        "--sharding_degree",
        type=int,
        default=1,
        help="Sharding degree. Shard the parameters across many cards.")

    parser.add_argument(
        "--dp_degree", type=int, default=1, help="Data Parallelism degree.")
    parser.add_argument(
        "--mp_degree",
        type=int,
        default=1,
        help="Model Parallelism degree. Splitting the linear layers across many cards."
    )
    parser.add_argument(
        "--pp_degree",
        type=int,
        default=1,
        help="Pipeline Parallelism degree. Splitting the model layers into different stages."
    )
    parser.add_argument(
        "--use_recompute",
        type=str2bool,
        nargs='?',
        const=False,
        help="Use recompute to save memory.")

    parser.add_argument(
        "--recompute_partition",
        type=str2bool,
        nargs='?',
        const=False,
        help="Use recompute_partition to support mp partition when use_recompute is True."
    )

    parser.add_argument(
        "--recompute_offload",
        type=str2bool,
        nargs='?',
        const=False,
        help="Use recompute_offload to save memory by offloading when use_recompute is True."
    )

    parser.add_argument(
        "--resume_dir",
        default="",
        type=str,
        required=True,
        help="The resume directory from which the checkpoint will be resumed.")

    # Pure FP16 config
    parser.add_argument(
        "--use_pure_fp16",
        type=str2bool,
        nargs='?',
        const=False,
        help="Enable pure fp16 precision training.")

    parser.add_argument(
        "--scale_loss",
        type=float,
        default=32768,
        help="The value of scale_loss for fp16. This is only used for AMP training."
    )

    parser.add_argument(
        "--hidden_dropout_prob",
        type=float,
        default=0.1,
        help="The hidden dropout prob.")

    parser.add_argument(
        "--attention_probs_dropout_prob",
        type=float,
        default=0.1,
        help="The attention probs dropout prob.")

    # MOE config
    parser.add_argument(
        "--num_experts",
        type=int,
        default=1,
        help="Number of experts per worker.")

    parser.add_argument(
        "--top_k", type=int, default=2, help="top_k for the MoE gate.")

    parser.add_argument(
        "--expert_mode",
        type=str2bool,
        nargs='?',
        const=False,
        help="Enable MoE mode.")

    parser.add_argument(
        "--balance_loss_weight",
        default=1.0,
        type=float,
        help="The weight of the auxiliary loss generated by the gate strategy to help balance experts."
    )

    parser.add_argument(
        "--gate",
        type=str,
        default="gshard",
        choices=["naive", "gshard", "switch"],
        help="Select the gate strategy: naive, gshard, or switch.")

    # Other config
    parser.add_argument(
        "--seed", type=int, default=1234, help="Random seed for initialization.")
    parser.add_argument(
        "--check_accuracy",
        type=str2bool,
        nargs='?',
        const=False,
        help="Check accuracy for the training process.")
    parser.add_argument(
        "--device",
        type=str,
        default="gpu",
        choices=["cpu", "gpu", "xpu"],
        help="Select the device: cpu, gpu, or xpu.")
    parser.add_argument(
        "--lr_decay_style",
        type=str,
        default="cosine",
        choices=["cosine", "none"],
        help="Learning rate decay style.")

    args = parser.parse_args()
    args.test_iters = args.eval_iters * 10

    # process batch size
    process_batch_size(args)

    if args.check_accuracy:
        if args.hidden_dropout_prob != 0:
            args.hidden_dropout_prob = .0
            logger.warning(
                "The hidden_dropout_prob should be set to 0 for accuracy checking.")
        if args.attention_probs_dropout_prob != 0:
            args.attention_probs_dropout_prob = .0
            logger.warning(
                "The attention_probs_dropout_prob should be set to 0 for accuracy checking."
            )

    logger.info('{:20}:{}'.format("paddle commit id", paddle.version.commit))
    for arg in vars(args):
        logger.info('{:20}:{}'.format(arg, getattr(args, arg)))

    return args
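A quick illustration of the batch-size handling in process_batch_size above: the sketch below uses made-up values (dp_degree=8, global_batch_size=64, micro_batch_size=2) to show how local_batch_size is derived and why it must split evenly into micro batches. The numbers are illustrative only, not taken from this commit.

import argparse

# Assumed, illustrative values -- not from the commit.
args = argparse.Namespace(
    global_batch_size=64,   # samples per optimizer step across all data-parallel ranks
    local_batch_size=None,  # left unset, so it is derived below
    micro_batch_size=2,     # samples per forward/backward pass on one device
    dp_degree=8)

# Same derivation process_batch_size performs when only global_batch_size is given.
args.local_batch_size = args.global_batch_size // args.dp_degree  # 64 // 8 = 8

# local_batch_size must be a multiple of micro_batch_size;
# here each rank accumulates 8 // 2 = 4 micro batches per step.
assert args.local_batch_size % args.micro_batch_size == 0
print(args.local_batch_size)  # 8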
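The str2bool flags above (use_recompute, use_pure_fp16, expert_mode, and so on) are declared with nargs='?' and const=False, which has a subtle consequence: passing the bare flag yields False, so a feature is only enabled by giving the flag an explicit value. A minimal sketch with a throwaway parser (the sample argv values are illustrative):

import argparse

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')

parser = argparse.ArgumentParser()
parser.add_argument("--expert_mode", type=str2bool, nargs='?', const=False)
parser.add_argument("--use_pure_fp16", type=str2bool, nargs='?', const=False)

# Explicit value: str2bool("true") -> True.
print(parser.parse_args(["--expert_mode", "true"]))  # Namespace(expert_mode=True, use_pure_fp16=None)
# Bare flag: argparse falls back to const, i.e. False.
print(parser.parse_args(["--use_pure_fp16"]))        # Namespace(expert_mode=None, use_pure_fp16=False)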
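parse_args(MODEL_CLASSES) only assumes that MODEL_CLASSES maps a model_type string to a tuple of classes whose last element exposes a pretrained_init_configuration dict, which is used to list the valid --model_name_or_path shortcuts. A hypothetical stand-in that satisfies that contract, with placeholder names not taken from the commit:

class FakeModel:
    pass

class FakeTokenizer:
    # Keys become the shortcut names listed in the --model_name_or_path help.
    pretrained_init_configuration = {"demo-moe-small": {}, "demo-moe-base": {}}

MODEL_CLASSES = {"demo": (FakeModel, FakeTokenizer)}

# parse_args(MODEL_CLASSES) would then accept --model_type demo and advertise
# demo-moe-small / demo-moe-base as shortcut names.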
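The flags --max_lr, --min_lr, --warmup_rate, --decay_steps, and --lr_decay_style describe a warmup-then-cosine schedule, but the scheduler itself lives elsewhere in the commit and is not shown in this hunk. Purely as an illustration of what those flag semantics imply (an assumption, not the commit's implementation):

import math

def lr_at(step, max_lr=0.00015, min_lr=1e-5, warmup_rate=0.01, decay_steps=360000):
    # Linear warmup over the first warmup_rate * decay_steps steps.
    warmup_steps = int(warmup_rate * decay_steps)
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # After decay_steps the rate is clamped to min_lr, as the --decay_steps help says.
    if step > decay_steps:
        return min_lr
    # Cosine decay from max_lr down to min_lr in between.
    progress = (step - warmup_steps) / (decay_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(1000), lr_at(200000), lr_at(400000))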
