@@ -637,7 +637,9 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None):
             elif isinstance(self.model, LoKrModel):
                 weights_file = os.path.join(resume_from_checkpoint, LOKR_WEIGHTS_NAME)
             elif isinstance(self.model, ReFTModel):
-                self.model.from_pretrained(resume_from_checkpoint, self.model.model)
+                self.model.from_pretrained(
+                    resume_from_checkpoint, self.model.model, convert_from_hf=self.args.convert_from_hf
+                )
                 return
 
             if self.args.dataset_rank == 0:
@@ -689,6 +691,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 self.unified_checkpoint_handler.load_unified_checkpoint(
                     self.model,
                     resume_from_checkpoint,
+                    convert_from_hf=self.args.convert_from_hf,
                 )
                 if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga:
                     self.model.reinit_base_model = True
@@ -1452,6 +1455,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 self.unified_checkpoint_handler.load_unified_checkpoint(
                     self.model,
                     self.state.best_model_checkpoint,
+                    convert_from_hf=self.args.convert_from_hf,
                 )
                 if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
                     broadcast_dataset_rank0_model(self.model)
@@ -1502,6 +1506,7 @@ def _load_best_model_from_peft_checkpoint(self):
             self.unified_checkpoint_handler.load_unified_checkpoint(
                 self.model,
                 self.state.best_model_checkpoint,
+                convert_from_hf=self.args.convert_from_hf,
             )
             if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
                 broadcast_dataset_rank0_model(self.model)
@@ -3010,7 +3015,9 @@ def _save(
             # backup and remove unified_checkpoint_config for not trine stage
             if not self.is_in_train:
                 self.args.unified_checkpoint_config = []
-            self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir)
+            self.unified_checkpoint_handler.save_unified_checkpoint(
+                self.model, self.optimizer, output_dir, signal_dir, save_to_hf=self.args.save_to_hf
+            )
 
             # recover unified_checkpoint_config for not trine stage
             if not self.is_in_train:
@@ -3034,6 +3041,7 @@ def _save(
                     merge_tensor_parallel=merge_tensor_parallel,
                     is_main_process=self.args.should_save,
                     max_shard_size="1024GB",
+                    save_to_hf=self.args.save_to_hf,
                 )
             # TODO: @ZHUI unify unwrap_model(self.model) and self.model
             elif not isinstance(self.model, PretrainedModel):
@@ -3052,6 +3060,7 @@ def _save(
                             save_function=self._save_ckpt_func,
                             is_main_process=self.args.should_save,
                             max_shard_size="1024GB",
+                            save_to_hf=self.args.save_to_hf,
                         )
                     else:
                         unwrap_model(self.model).save_pretrained(
@@ -3061,6 +3070,7 @@ def _save(
                             save_function=self._save_ckpt_func,
                             is_main_process=self.args.should_save,
                             max_shard_size="1024GB",
+                            save_to_hf=self.args.save_to_hf,
                         )
                 else:
                     logger.info("Trainer.model is not a `PretrainedModel`, only saving its state dict.")
@@ -3093,6 +3103,7 @@ def _save(
                         save_function=self._save_ckpt_func,
                         is_main_process=self.args.should_save,
                         max_shard_size="1024GB",
+                        save_to_hf=self.args.save_to_hf,
                     )
                 else:
                     self.model.save_pretrained(
@@ -3102,6 +3113,7 @@ def _save(
                         save_function=self._save_ckpt_func,
                         is_main_process=self.args.should_save,
                         max_shard_size="1024GB",
+                        save_to_hf=self.args.save_to_hf,
                     )
         if self.args.should_save_sharding_stage1_model:
             model_meta = self.sharding_io.gather_distributed_model_meta()
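The hunks above all follow one pattern: two new boolean flags are read from the trainer arguments and forwarded to the unified-checkpoint calls, `convert_from_hf` on every load path and `save_to_hf` on every save path. Below is a minimal sketch of that wiring, assuming the flags are plain boolean fields on the training arguments; the classes are hypothetical stand-ins, not the real PaddleNLP Trainer or UnifiedCheckpointHandler APIs.

# Illustrative sketch only; the argument and handler classes are hypothetical stand-ins.
from dataclasses import dataclass


@dataclass
class StubArgs:
    convert_from_hf: bool = False  # read checkpoints stored in Hugging Face layout (assumed semantics)
    save_to_hf: bool = False       # write checkpoints in Hugging Face layout (assumed semantics)


class StubUnifiedCheckpointHandler:
    def load_unified_checkpoint(self, model, checkpoint_dir, convert_from_hf=False):
        # The real handler would map HF-format weights onto the Paddle model when the flag is set.
        print(f"loading {checkpoint_dir} (convert_from_hf={convert_from_hf})")

    def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir, save_to_hf=False):
        # The real handler would emit HF-compatible weight files when the flag is set.
        print(f"saving to {output_dir} (save_to_hf={save_to_hf})")


args = StubArgs(convert_from_hf=True, save_to_hf=True)
handler = StubUnifiedCheckpointHandler()

# Mirrors how the trainer forwards the flags in the hunks above:
handler.load_unified_checkpoint(None, "./checkpoint-100", convert_from_hf=args.convert_from_hf)
handler.save_unified_checkpoint(None, None, "./output", "./signal", save_to_hf=args.save_to_hf)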