
Commit d4acbfc

[Unified Checkpoint] Fix last checkpoint save (#7810)
* fix(unified checkpoint): fix last ckpt save
* fix(unified checkpoint): unified config format
1 parent 672ee98 commit d4acbfc

File tree

2 files changed: +11 -0 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 9 additions & 0 deletions

@@ -2027,8 +2027,17 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             self.model_wrapped.get_all_parameters(convert2cpu=True)
 
         if self.args.should_save_model_state:
+            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
+            # back up and clear unified_checkpoint_config for the non-train stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = []
+
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
 
+            # restore unified_checkpoint_config for the non-train stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
+
     def _save_checkpoint(self, model, metrics=None):
         # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
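In short, `save_model` now wraps `self._save(...)` in a backup/restore of `args.unified_checkpoint_config`: when the save happens outside the training loop (`is_in_train` is False), the unified-checkpoint flags are temporarily cleared so the final save takes the plain path, then put back so later checkpoint saves are unaffected. A minimal sketch of that pattern follows; `TinyTrainer`, its `SimpleNamespace` args, and the stubbed `_save` are hypothetical stand-ins, not PaddleNLP's actual implementation.

from types import SimpleNamespace


class TinyTrainer:
    def __init__(self):
        # unified_checkpoint_config is a list of feature flags (e.g. ["async_save"])
        self.args = SimpleNamespace(
            unified_checkpoint_config=["async_save"],
            should_save_model_state=True,
        )
        self.is_in_train = False  # final save after training has finished

    def _save(self, output_dir=None):
        # stand-in for the real save: just report which flags are in effect
        print(f"saving to {output_dir} with flags={self.args.unified_checkpoint_config}")

    def save_model(self, output_dir=None):
        if self.args.should_save_model_state:
            # back up the flags and drop them for the non-train (final) save
            backup = self.args.unified_checkpoint_config
            if not self.is_in_train:
                self.args.unified_checkpoint_config = []

            self._save(output_dir=output_dir)

            # restore the flags so subsequent saves behave as configured
            if not self.is_in_train:
                self.args.unified_checkpoint_config = backup


TinyTrainer().save_model("output/")  # prints: saving to output/ with flags=[]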

paddlenlp/trainer/training_args.py

Lines changed: 2 additions & 0 deletions

@@ -1310,6 +1310,8 @@ def is_segment_parallel_supported():
                 "master_weight_compatible",
                 "async_save",
             ]
+        else:
+            self.unified_checkpoint_config = self.unified_checkpoint_config.split(" ")
 
         if self.report_to is None:
             logger.info(
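This second change unifies the config format: when `unified_checkpoint_config` is supplied as a space-separated string, it is split into a list, so downstream code always works with a list of flags. A minimal sketch of that normalization, assuming a free function in place of the real `__post_init__` branch and abbreviating the default flag list to the two entries visible in the hunk:

def normalize_unified_checkpoint_config(value):
    # Hypothetical helper illustrating the branch above; the real default
    # flag list in training_args.py may contain more entries.
    if not value:
        # unset or empty: fall back to the built-in default flags
        return ["master_weight_compatible", "async_save"]
    # user-supplied space-separated string -> list of flags
    return value.split(" ")


print(normalize_unified_checkpoint_config("master_weight_compatible async_save"))
# ['master_weight_compatible', 'async_save']
print(normalize_unified_checkpoint_config(""))
# ['master_weight_compatible', 'async_save']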
