File tree Expand file tree Collapse file tree 3 files changed +386
-0
lines changed Expand file tree Collapse file tree 3 files changed +386
-0
lines changed Original file line number Diff line number Diff line change 170
170
nested_numpify ,
171
171
nested_truncate ,
172
172
)
173
+ from .utils .load_hf_ckpt import load_huggingface_ckpt
173
174
from .utils .sharding_io import ShardingIO
174
175
175
176
DEFAULT_CALLBACKS = [DefaultFlowCallback ]
@@ -1009,6 +1010,9 @@ def _inner_training_loop(
1009
1010
if self .args .ignore_data_skip :
1010
1011
self .timers and self .timers ("read-data" ).start ()
1011
1012
1013
+ if self .args .resume_from_huggingface_ckpt is not None :
1014
+ load_huggingface_ckpt (model , self .args .resume_from_huggingface_ckpt )
1015
+
1012
1016
for epoch in range (epochs_trained , num_train_epochs ):
1013
1017
if isinstance (train_dataloader , paddle .io .DataLoader ) and isinstance (
1014
1018
train_dataloader .batch_sampler , DistributedBatchSampler
Original file line number Diff line number Diff line change @@ -885,6 +885,10 @@ class TrainingArguments:
885
885
default = None ,
886
886
metadata = {"help" : "The path to a folder with a valid checkpoint for your model." },
887
887
)
888
+ resume_from_huggingface_ckpt : Optional [str ] = field (
889
+ default = None ,
890
+ metadata = {"help" : "The path to a folder with a valid huggingface checkpoint for your model." },
891
+ )
888
892
auto_parallel_resume_form_hybrid_parallel : Optional [bool ] = field (
889
893
default = False ,
890
894
metadata = {"help" : "Wether hybrid paralle checkpoints be loaded in auto parallel mode." },
You can’t perform that action at this time.
0 commit comments