@@ -30,23 +30,24 @@
 import lr
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer

 MODEL_CLASSES = {
     "gpt": (GPTForPretraining, GPTTokenizer),
     "gpt-cn": (GPTForPretraining, GPTChineseTokenizer),
 }


-def set_hyrbid_parallel_seed(basic_seed, dp_rank, mp_rank, pp_rank):
+def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank):
     assert args.device != "cpu"

-    random.seed(basic_seed + dp_rank)
-    np.random.seed(basic_seed + dp_rank)
-    paddle.seed(basic_seed + dp_rank)
+    random.seed(basic_seed + data_world_rank)
+    np.random.seed(basic_seed + data_world_rank)
+    paddle.seed(basic_seed + data_world_rank)

     # local_seed/ global_seed is used to control dropout in ModelParallel
     local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
-    global_seed = basic_seed + dp_rank
+    global_seed = basic_seed + data_world_rank
     tracker = get_rng_state_tracker()
     tracker.add('global_seed', global_seed)
     tracker.add('local_seed', local_seed)
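For intuition, a minimal sketch (plain Python, hypothetical seed and rank values, not part of the patch) of how the two RNG streams separate: ranks in the same data-parallel/sharding group share global_seed, while local_seed differs per model-parallel shard and pipeline stage so dropout masks are not replicated across tensor-parallel partitions.

# Hypothetical illustration of the seed arithmetic above.
basic_seed = 1234

def seeds(data_world_rank, mp_rank, pp_rank):
    local_seed = basic_seed + 123 + mp_rank * 10 + pp_rank * 1000
    global_seed = basic_seed + data_world_rank
    return local_seed, global_seed

# Two tensor-parallel shards of the same data-parallel replica:
print(seeds(data_world_rank=0, mp_rank=0, pp_rank=0))  # (1357, 1234): same global seed
print(seeds(data_world_rank=0, mp_rank=1, pp_rank=0))  # (1367, 1234): different local seed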
@@ -92,14 +93,18 @@ def do_train(args):
     strategy.hybrid_configs = {
         "dp_degree": args.dp_degree,
         "mp_degree": args.mp_degree,
-        "pp_degree": args.pp_degree
+        "pp_degree": args.pp_degree,
+        "sharding_degree": args.sharding_degree
     }

     strategy.pipeline_configs = {
         "accumulate_steps": args.local_batch_size // args.micro_batch_size,
         "micro_batch_size": args.micro_batch_size
     }

+    # set control in tensor parallel
+    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}
+
     fleet.init(is_collective=True, strategy=strategy)

     # obtain rank message of hybrid parallel
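As a hedged rule of thumb (based on how fleet's hybrid topology is typically configured, not stated in this patch), the four degrees are expected to multiply to the number of launched processes. A small sanity-check sketch with hypothetical argument values:

# Hypothetical sanity check: the hybrid degrees should factor the world size,
# e.g. 16 GPUs split as dp=2, sharding=2, mp=2, pp=2.
dp_degree, sharding_degree, mp_degree, pp_degree = 2, 2, 2, 2
world_size = 16
assert dp_degree * sharding_degree * mp_degree * pp_degree == world_size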
@@ -108,10 +113,15 @@ def do_train(args):
     mp_rank = hcg.get_model_parallel_rank()
     pp_rank = hcg.get_stage_id()
     dp_rank = hcg.get_data_parallel_rank()
+    sharding_rank = hcg.get_sharding_parallel_rank()
+
+    sharding_size = hcg.get_sharding_parallel_world_size()
+    data_world_rank = dp_rank * sharding_size + sharding_rank
+    data_world_size = args.dp_degree * args.sharding_degree
     local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))

     # seed control in hybrid parallel
-    set_hyrbid_parallel_seed(args.seed, dp_rank, mp_rank, pp_rank)
+    set_hyrbid_parallel_seed(args.seed, data_world_rank, mp_rank, pp_rank)

     default_global_tokens_num = args.global_batch_size * args.max_seq_len

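The sharding dimension is folded into the data-parallel dimension here, so every (dp_rank, sharding_rank) pair consumes a distinct slice of the data and gets a distinct data seed. A small sketch with hypothetical degrees showing that the flattening is collision-free:

# Hypothetical example (dp_degree=2, sharding_degree=4): flattening gives each
# of the 8 data-consuming ranks a unique index in [0, 8).
dp_degree, sharding_size = 2, 4
data_world_size = dp_degree * sharding_size
ranks = [dp * sharding_size + sh
         for dp in range(dp_degree) for sh in range(sharding_size)]
assert sorted(ranks) == list(range(data_world_size))  # 0..7, no collisions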
@@ -183,15 +193,31 @@ def do_train(args):
         if not any(nd in n for nd in ["bias", "norm"])
     ]

-    optimizer = paddle.optimizer.AdamW(
-        learning_rate=lr_scheduler if lr_scheduler is not None else args.max_lr,
-        beta1=args.adam_beta1,
-        beta2=args.adam_beta2,
-        epsilon=args.adam_epsilon,
-        parameters=model.parameters(),
-        weight_decay=args.weight_decay,
-        grad_clip=clip,
-        apply_decay_param_fun=lambda x: x in decay_params)
+    if args.sharding_degree > 1:
+        optimizer = DygraphShardingOptimizer(
+            hcg=fleet.get_hybrid_communicate_group(),
+            user_defined_strategy=strategy,
+            params=model.parameters(),
+            inner_optimizer_class=paddle.optimizer.AdamW,
+            learning_rate=lr_scheduler
+            if lr_scheduler is not None else args.max_lr,
+            beta1=args.adam_beta1,
+            beta2=args.adam_beta2,
+            epsilon=args.adam_epsilon,
+            weight_decay=args.weight_decay,
+            grad_clip=clip,
+            apply_decay_param_fun=lambda x: x in decay_params)
+    else:
+        optimizer = paddle.optimizer.AdamW(
+            learning_rate=lr_scheduler
+            if lr_scheduler is not None else args.max_lr,
+            beta1=args.adam_beta1,
+            beta2=args.adam_beta2,
+            epsilon=args.adam_epsilon,
+            parameters=model.parameters(),
+            weight_decay=args.weight_decay,
+            grad_clip=clip,
+            apply_decay_param_fun=lambda x: x in decay_params)

     if paddle.distributed.get_world_size() > 1:
         model = fleet.distributed_model(model)
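The branch above switches to DygraphShardingOptimizer when sharding is enabled, passing the same AdamW settings through inner_optimizer_class; the optimizer states are then partitioned across the sharding group (in the spirit of ZeRO stage 1). A rough, hypothetical back-of-the-envelope on why that helps memory, assuming fp32 Adam moments only:

# Hypothetical memory estimate (illustration only): AdamW keeps two moment
# tensors per parameter, roughly 8 bytes/param in fp32; sharding divides that
# state across the sharding group.
n_params = 1.3e9              # e.g. a GPT-1.3B-scale configuration
sharding_degree = 4
state_bytes = 8 * n_params
print(f"unsharded: {state_bytes / 2**30:.1f} GiB of optimizer state per rank")
print(f"sharded  : {state_bytes / sharding_degree / 2**30:.1f} GiB per rank")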
@@ -227,8 +253,8 @@ def do_train(args):
                 args,
                 data_file,
                 local_rank=local_rank,
-                data_world_size=args.dp_degree,
-                data_world_rank=dp_rank,
+                data_world_size=data_world_size,
+                data_world_rank=data_world_rank,
                 eos_id=tokenizer.eos_token_id)
             # Bug fix, if not call valid_data_loader, the enumerate will call valid_data_loader
             # many times. and start a new random dataloader.
@@ -309,6 +335,7 @@ def do_train(args):
                              args.eval_iters, log_writer, global_step,
                              epoch, "valid")

+            # TODO: 1. merge parameters while saving model. 2. ensure that the model is saved and loaded correctly
             # only dp_rank = 0 save model
             if (global_step % args.save_steps == 0 or
                     global_step >= args.max_steps) and dp_rank == 0:
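The hunk below reworks checkpoint saving so that each (mp, sharding[, pp]) rank writes its own optimizer shard into output_dir, while only one rank writes the tokenizer. A small sketch of the filename scheme it produces, with hypothetical rank values:

# Hypothetical filenames produced by the format strings in the next hunk.
mp_rank, sharding_rank, pp_rank = 1, 0, 3
print("model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt"
      .format(mp_rank, sharding_rank, pp_rank))   # model_state_mp_01_sharding_00_pp_03.pdopt
print("model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt"
      .format(mp_rank, sharding_rank))            # model_state_mp_01_sharding_00.pdopt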
@@ -322,24 +349,25 @@ def do_train(args):
                 logger.info("Save model to %s" % output_dir)

                 if args.pp_degree > 1:
-                    model_to_save.save_state_dict(output_dir)
-                    if mp_rank * pp_rank == 1:
+                    if mp_rank == 0 and sharding_rank == 0 and pp_rank == 0:
                         tokenizer.save_pretrained(output_dir)
+                    model_to_save.save_state_dict(output_dir)
                     paddle.save(
                         optimizer.state_dict(),
                         os.path.join(
                             output_dir,
-                            "model_state_mp_{:0>2d}_pp_{:0>2d}.pdopt".
-                            format(mp_rank, pp_rank)))
+                            "model_state_mp_{:0>2d}_sharding_{:0>2d}_pp_{:0>2d}.pdopt".
+                            format(mp_rank, sharding_rank, pp_rank)))
                 else:
-                    path = os.path.join(output_dir,
-                                        'model_{:0>2d}'.format(mp_rank))
-                    os.makedirs(path, exist_ok=True)
-                    model_to_save.save_pretrained(path)
-
-                    paddle.save(optimizer.state_dict(),
-                                os.path.join(path, "model_state.pdopt"))
-                    tokenizer.save_pretrained(path)
+                    if mp_rank == 0 and sharding_rank == 0:
+                        tokenizer.save_pretrained(output_dir)
+                    model_to_save.save_pretrained(output_dir)
+                    paddle.save(
+                        optimizer.state_dict(),
+                        os.path.join(
+                            output_dir,
+                            "model_state_mp_{:0>2d}_sharding_{:0>2d}.pdopt".
+                            format(mp_rank, sharding_rank)))

             if global_step >= args.max_steps:
                 run_evaluate(args, test_data_loader, model, criterion,