    nested_numpify,
    nested_truncate,
)
+from .utils.sharding_io import ShardingIO

DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback
@@ -285,6 +286,9 @@ def __init__(
        self.control = TrainerControl()
        self._signature_columns = None
        self.optimizer_grouped_parameters = None
+        self.sharding_io = None
+        if self.args.should_save_sharding_stage1_model or self.args.should_load_sharding_stage1_model:
+            self.sharding_io = ShardingIO(self.args, self.model, self.optimizer)

        if self.sharding is not None and self.optimizer is not None:
            raise RuntimeError(
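Note on the constructor change above: ShardingIO is built only when one of the sharding-stage1 flags is set, and at this point self.optimizer may not yet be its final (wrapped) instance; the wrapped optimizer is handed over later via set_optimizer (see the hunks below). A minimal sketch of that two-phase wiring, using hypothetical stand-in classes rather than the real PaddleNLP ones:

# Hypothetical stand-ins, for illustration only: the helper is created eagerly
# in __init__ but receives the final (wrapped) optimizer later.
class FakeShardingIO:
    def __init__(self, args, model, optimizer):
        self.args, self.model, self.optimizer = args, model, optimizer

    def set_optimizer(self, optimizer):
        # called again after _wrap_model, once the optimizer has been wrapped
        self.optimizer = optimizer


class FakeTrainer:
    def __init__(self, args, model, optimizer=None):
        self.args, self.model, self.optimizer = args, model, optimizer
        self.sharding_io = None
        if args.get("should_save_sharding_stage1_model") or args.get("should_load_sharding_stage1_model"):
            self.sharding_io = FakeShardingIO(args, model, optimizer)


trainer = FakeTrainer({"should_load_sharding_stage1_model": True}, model=object())
assert trainer.sharding_io is not None and trainer.sharding_io.optimizer is None
trainer.sharding_io.set_optimizer("wrapped-optimizer")  # done after _wrap_model in the real code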
@@ -428,33 +432,52 @@ def load_state_dict_from_checkpoint(self, resume_from_checkpoint=None):
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")

-        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
-            if isinstance(self.model, LoRAModel):
-                weight_name = LORA_WEIGHTS_NAME
-            elif isinstance(self.model, PrefixModelForCausalLM):
-                weight_name = PREFIX_WEIGHTS_NAME
-            else:
-                weight_name = PADDLE_WEIGHTS_NAME
+        if isinstance(self.model, LoRAModel):
+            weight_name = LORA_WEIGHTS_NAME
+        elif isinstance(self.model, PrefixModelForCausalLM):
+            weight_name = PREFIX_WEIGHTS_NAME
+        else:
+            weight_name = PADDLE_WEIGHTS_NAME

-            if not os.path.isfile(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix))
-            ):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+        if self.args.should_load_sharding_stage1_model:
+            state_dict = self.sharding_io.load_state_dict_from_checkpoint_with_reshard(
+                resume_from_checkpoint,
+                base_weight_name=weight_name,
+            )
+        else:
+            if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
+                file_path = os.path.join(
+                    resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
+                )
+                if not os.path.isfile(file_path):
+                    raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}, no {file_path}")

-            logger.info(f"Loading model from {resume_from_checkpoint}.")
+                logger.info(f"Loading model from {resume_from_checkpoint}.")

-            # We load the model state dict on the CPU to avoid an OOM error.
-            state_dict = paddle.load(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)),
-                return_numpy=True,
-            )
-            # If the model is on the GPU, it still works!
-            self._set_state_dict_in_model(state_dict)
+                # We load the model state dict on the CPU to avoid an OOM error.
+                state_dict = paddle.load(file_path, return_numpy=True)

-            # release memory
-            del state_dict
-        elif resume_from_checkpoint is not None:
-            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+                # If the model is on the GPU, it still works!
+                self._set_state_dict_in_model(state_dict)
+                # release memory
+                del state_dict
+            elif resume_from_checkpoint is not None:
+                logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+
+    def _wrap_model_and_load_sharded_checkpoint(self, resume_from_checkpoint):
+        # In sharded mode, load_state_dict_from_checkpoint must be invoked after _wrap_model:
+        # each sharding rank loads its own shard of the params, so no broadcast logic is needed.
+        model = self._wrap_model(self.model_wrapped)
+        if self.sharding_io is not None:
+            # self.optimizer must already be wrapped; that is done in _wrap_model
+            self.sharding_io.set_optimizer(self.optimizer)
+        if model is not self.model:
+            self.model_wrapped = model
+        # load_state_dict_from_checkpoint must be invoked after _load_optimizer_and_scheduler,
+        # because in sharded mode it relies on the optimizer.
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        self.load_state_dict_from_checkpoint(resume_from_checkpoint)
+        return model

    def train(
        self,
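The new _wrap_model_and_load_sharded_checkpoint helper pins down an ordering constraint that is easy to get wrong: wrap first (which also wraps the optimizer), hand the wrapped optimizer to ShardingIO, restore the optimizer state, and only then load the per-rank model shards. A rough, self-contained sketch that just records the call order; all names below are stand-ins, not the real trainer:

# Stand-in objects that only record call order; this is not the real trainer,
# just an illustration of the sequence the helper above enforces.
calls = []

class RecordingShardingIO:
    def set_optimizer(self, optimizer):
        calls.append("set_optimizer")

class RecordingTrainer:
    def __init__(self):
        self.model = object()
        self.model_wrapped = self.model
        self.optimizer = object()
        self.sharding_io = RecordingShardingIO()

    def _wrap_model(self, model):
        calls.append("wrap_model")       # wraps model and optimizer for sharding
        return model

    def _load_optimizer_and_scheduler(self, checkpoint):
        calls.append("load_optimizer")   # per-rank optimizer state, possibly resharded

    def load_state_dict_from_checkpoint(self, checkpoint):
        calls.append("load_weights")     # relies on the optimizer in sharded mode

    def wrap_and_load(self, checkpoint):
        model = self._wrap_model(self.model_wrapped)
        self.sharding_io.set_optimizer(self.optimizer)
        if model is not self.model:
            self.model_wrapped = model
        self._load_optimizer_and_scheduler(checkpoint)
        self.load_state_dict_from_checkpoint(checkpoint)
        return model

RecordingTrainer().wrap_and_load("path/to/checkpoint")
assert calls == ["wrap_model", "set_optimizer", "load_optimizer", "load_weights"]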
@@ -475,43 +498,12 @@ def train(
        """
        args = self.args
        self.is_in_train = True
-        resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

-        # Load potential model checkpoint
-        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
-            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
-            if resume_from_checkpoint is None:
-                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
-
-        if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
-            if isinstance(self.model, LoRAModel):
-                weight_name = LORA_WEIGHTS_NAME
-            elif isinstance(self.model, PrefixModelForCausalLM):
-                weight_name = PREFIX_WEIGHTS_NAME
-            else:
-                weight_name = PADDLE_WEIGHTS_NAME
-            if not os.path.isfile(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix))
-            ):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
-            logger.info(f"Loading model from {resume_from_checkpoint}.")
-
-            # TODO: Need to load the model state dict on the CPU to avoid an OOM error.
-            state_dict = paddle.load(
-                os.path.join(resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)),
-                return_numpy=True,
-            )
-            # If the model is on the GPU, it still works!
-            self._set_state_dict_in_model(state_dict)
-
-            # release memory
-            del state_dict
-        elif resume_from_checkpoint is not None:
-            logger.info(f"not loading ckpt :{self.args.dataset_rank}")
+        if not self.args.should_load_sharding_stage1_model:
+            self.load_state_dict_from_checkpoint(resume_from_checkpoint)

        train_dataloader = self.get_train_dataloader()

@@ -563,17 +555,30 @@ def train(

        self.state = TrainerState()

-        model = self._wrap_model(self.model_wrapped)
-
-        # for the rest of this function `model` is the outside model, whether it was wrapped or not
-        if model is not self.model:
-            self.model_wrapped = model
-
-        if delay_optimizer_creation:
-            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-        # Check if saved optimizer or scheduler states exist
-        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        if self.args.should_load_sharding_stage1_model:
+            model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
+        elif self.args.should_save_sharding_stage1_model:
+            # In non-sharded mode, load_state_dict_from_checkpoint must be invoked before _wrap_model:
+            # rank 0 loads all params and _wrap_model implicitly broadcasts them from rank 0 to the other ranks.
+            model = self._wrap_model(self.model_wrapped)
+            if self.sharding_io is not None:
+                assert delay_optimizer_creation is False, "delay_optimizer_creation should be False"
+                # self.optimizer must already be wrapped; that is done in _wrap_model
+                self.sharding_io.set_optimizer(self.optimizer)
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+            if delay_optimizer_creation:
+                self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+            self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        else:
+            model = self._wrap_model(self.model_wrapped)
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+            if delay_optimizer_creation:
+                self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+            self._load_optimizer_and_scheduler(resume_from_checkpoint)

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
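To keep the three branches above straight, here is a compact decision table written as code. The two flags are the TrainingArguments properties used in this diff; the helper function and the returned labels are illustrative only:

def resume_strategy(should_load_sharding_stage1_model: bool,
                    should_save_sharding_stage1_model: bool) -> str:
    """Hypothetical summary of which resume path train() takes."""
    if should_load_sharding_stage1_model:
        # wrap first; every sharding rank then loads its own optimizer and weight shards
        return "wrap_then_load_sharded_per_rank"
    if should_save_sharding_stage1_model:
        # weights were already loaded (rank 0 only) before wrapping; _wrap_model broadcasts them,
        # and delay_optimizer_creation must be False so ShardingIO can take the wrapped optimizer
        return "load_on_rank0_then_wrap_and_broadcast"
    # unchanged default path
    return "wrap_then_load_optimizer_and_scheduler"

assert resume_strategy(True, True) == "wrap_then_load_sharded_per_rank"
assert resume_strategy(False, True) == "load_on_rank0_then_wrap_and_broadcast"
assert resume_strategy(False, False) == "wrap_then_load_optimizer_and_scheduler"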
@@ -1893,12 +1898,26 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
            )
        elif not isinstance(self.model, PretrainedModel):
            if isinstance(unwrap_model(self.model), PretrainedModel):
-                unwrap_model(self.model).save_pretrained(
-                    output_dir,
-                    merge_tensor_parallel=merge_tensor_parallel,
-                    variant=self.args.weight_name_suffix,
-                    is_main_process=self.args.should_save,
-                )
+                if self.args.should_save_sharding_stage1_model:
+                    config_to_save = None
+                    state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
+                        unwrap_model(self.model), merge_tensor_parallel=merge_tensor_parallel
+                    )
+                    unwrap_model(self.model).save_pretrained(
+                        output_dir,
+                        state_dict=state_dict,
+                        config_to_save=config_to_save,
+                        merge_tensor_parallel=merge_tensor_parallel,
+                        variant=weight_name_suffix,
+                        is_main_process=self.args.should_save,
+                    )
+                else:
+                    unwrap_model(self.model).save_pretrained(
+                        output_dir,
+                        merge_tensor_parallel=merge_tensor_parallel,
+                        variant=self.args.weight_name_suffix,
+                        is_main_process=self.args.should_save,
+                    )
            else:
                logger.info("Trainer.model is not a `PretrainedModel`, only saving its state dict.")
                if merge_tensor_parallel:
@@ -1910,12 +1929,28 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                    os.path.join(output_dir, _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix)),
                )
        else:
-            self.model.save_pretrained(
-                output_dir,
-                merge_tensor_parallel=merge_tensor_parallel,
-                variant=self.args.weight_name_suffix,
-                is_main_process=self.args.should_save,
-            )
+            if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
+                config_to_save = None
+                state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
+                    self.model, merge_tensor_parallel=merge_tensor_parallel
+                )
+                self.model.save_pretrained(
+                    output_dir,
+                    state_dict=state_dict,
+                    config_to_save=config_to_save,
+                    merge_tensor_parallel=merge_tensor_parallel,
+                    variant=weight_name_suffix,
+                    is_main_process=self.args.should_save,
+                )
+            else:
+                self.model.save_pretrained(
+                    output_dir,
+                    merge_tensor_parallel=merge_tensor_parallel,
+                    variant=self.args.weight_name_suffix,
+                    is_main_process=self.args.should_save,
+                )
+        if self.args.should_save_sharding_stage1_model:
+            self.sharding_io.save_distributed_model_meta(output_dir)

        if self.args.should_save:
            if self.tokenizer is not None:
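Both _save branches above follow the same sharded-save recipe: ask ShardingIO for the per-rank state dict, config and weight-name suffix, pass those to save_pretrained, then write the distributed model meta. Purely as an illustration (this helper is not part of the PR), the shared recipe could be expressed as:

def save_sharded_stage1(model, sharding_io, output_dir, merge_tensor_parallel, is_main_process):
    # Illustrative helper only; the calls on sharding_io and save_pretrained mirror the diff.
    state_dict, config_to_save, weight_name_suffix = sharding_io.manipulate_state_dict_and_config(
        model, merge_tensor_parallel=merge_tensor_parallel
    )
    model.save_pretrained(
        output_dir,
        state_dict=state_dict,
        config_to_save=config_to_save,
        merge_tensor_parallel=merge_tensor_parallel,
        variant=weight_name_suffix,          # per-rank variant computed by ShardingIO
        is_main_process=is_main_process,
    )
    # record the distributed layout so the checkpoint can be resharded on load
    sharding_io.save_distributed_model_meta(output_dir)

Whether the duplication is worth folding into such a helper is a style call; the diff keeps the two call sites explicit.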
@@ -1929,13 +1964,20 @@ def _load_optimizer_and_scheduler(self, checkpoint):
        if checkpoint is None:
            return

-        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        opt_state_dict = None
+        if self.args.should_load_sharding_stage1_model:
+            opt_state_dict = self.sharding_io.load_optimizer_state_with_reshard(
+                checkpoint, base_opt_name=OPTIMIZER_NAME
+            )
+        else:
+            optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+            path = os.path.join(checkpoint, optimizer_name)
+            if os.path.isfile(path):
+                opt_state_dict = paddlenlp_load(path, map_location="cpu")

-        if os.path.isfile(os.path.join(checkpoint, optimizer_name)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        if opt_state_dict is not None and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
            # Load in optimizer and scheduler states
-            self.optimizer.set_state_dict(paddlenlp_load(os.path.join(checkpoint, optimizer_name), map_location="cpu"))
+            self.optimizer.set_state_dict(opt_state_dict)

            self.lr_scheduler.set_state_dict(paddle.load(os.path.join(checkpoint, SCHEDULER_NAME)))
        if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
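The optimizer-loading change funnels both routes into a single opt_state_dict that is then fed to set_state_dict. For the non-sharded route, the file-name resolution works roughly like the sketch below; OPTIMIZER_NAME and _add_variant here are simplified stand-ins (the real helpers live in PaddleNLP), so the exact file-name format may differ:

import os

OPTIMIZER_NAME = "optimizer.pdopt"  # assumed default name, for illustration only

def _add_variant(name, variant):
    # simplified stand-in for the real helper: insert the per-rank suffix before the extension
    if not variant:
        return name
    stem, ext = os.path.splitext(name)
    return f"{stem}-{variant}{ext}"

def resolve_optimizer_state(checkpoint, optimizer_name_suffix, load_fn):
    # mirrors the else-branch above: load to CPU if the per-rank file exists, otherwise keep None
    opt_state_dict = None
    path = os.path.join(checkpoint, _add_variant(OPTIMIZER_NAME, optimizer_name_suffix))
    if os.path.isfile(path):
        opt_state_dict = load_fn(path)  # paddlenlp_load(path, map_location="cpu") in the diff
    return opt_state_dict

print(_add_variant(OPTIMIZER_NAME, "tp00_shard00"))  # -> optimizer-tp00_shard00.pdopt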