Skip to content

Commit a3e2344

Browse files
authored
support load hf ckpt (#10976)
1 parent 73f451c commit a3e2344

File tree

3 files changed

+386
-0
lines changed

3 files changed

+386
-0
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@
170170
nested_numpify,
171171
nested_truncate,
172172
)
173+
from .utils.load_hf_ckpt import load_huggingface_ckpt
173174
from .utils.sharding_io import ShardingIO
174175

175176
DEFAULT_CALLBACKS = [DefaultFlowCallback]
@@ -1009,6 +1010,9 @@ def _inner_training_loop(
10091010
if self.args.ignore_data_skip:
10101011
self.timers and self.timers("read-data").start()
10111012

1013+
if self.args.resume_from_huggingface_ckpt is not None:
1014+
load_huggingface_ckpt(model, self.args.resume_from_huggingface_ckpt)
1015+
10121016
for epoch in range(epochs_trained, num_train_epochs):
10131017
if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
10141018
train_dataloader.batch_sampler, DistributedBatchSampler

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,10 @@ class TrainingArguments:
885885
default=None,
886886
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
887887
)
888+
resume_from_huggingface_ckpt: Optional[str] = field(
889+
default=None,
890+
metadata={"help": "The path to a folder with a valid huggingface checkpoint for your model."},
891+
)
888892
auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field(
889893
default=False,
890894
metadata={"help": "Wether hybrid paralle checkpoints be loaded in auto parallel mode."},

0 commit comments

Comments
 (0)