
Commit 8363f88

Support HF torch load & save (#2437)
Co-authored-by: llbdyiu66 <[email protected]>
1 parent 6aba537 commit 8363f88

22 files changed: +520 / −650 lines changed

paddleformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from typing import TYPE_CHECKING

 from .utils.lazy_import import _LazyModule
+from .utils.paddle_patch import *

 PADDLEFORMERS_STABLE_VERSION = "PADDLEFORMERS_STABLE_VERSION"

paddleformers/nn/attention/eager_attention.py

Lines changed: 4 additions & 9 deletions
@@ -31,15 +31,10 @@ def eager_attention_forward(
     is_causal: Optional[bool] = None,
     **kwargs,
 ):
-    num_key_value_heads = None
-    if hasattr(module, "num_key_value_heads"):
-        num_key_value_heads = module.num_key_value_heads
-    elif hasattr(module, "num_key_value_groups"):
-        num_key_value_heads = module.num_key_value_groups
-
-    if num_key_value_heads is not None:
-        key = repeat_kv(key, module.num_key_value_heads)
-        value = repeat_kv(value, module.num_key_value_heads)
+    if hasattr(module, "num_key_value_groups"):
+        num_key_value_groups = module.num_key_value_groups
+        key = repeat_kv(key, num_key_value_groups)
+        value = repeat_kv(value, num_key_value_groups)

     perm = [0, 2, 1, 3]  # b l h d -> b h l d
     query = paddle.transpose(x=query, perm=perm)
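
Note: repeat_kv here is the usual grouped-query-attention expansion: each key/value head is duplicated num_key_value_groups times (num_attention_heads // num_key_value_heads) so the KV tensors match the query head count before the matmul. A minimal sketch of that pattern, with the [batch, seq, heads, dim] layout assumed for illustration rather than taken from the repo's utils.repeat_kv:

import paddle


def repeat_kv_sketch(x: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    # x: [batch, seq_len, num_kv_heads, head_dim]; n_rep = num_key_value_groups.
    if n_rep == 1:
        return x
    b, s, h_kv, d = x.shape
    x = paddle.unsqueeze(x, axis=3)                       # [b, s, h_kv, 1, d]
    x = paddle.tile(x, repeat_times=[1, 1, 1, n_rep, 1])  # [b, s, h_kv, n_rep, d]
    return paddle.reshape(x, [b, s, h_kv * n_rep, d])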

paddleformers/nn/attention/sdpa_attention.py

Lines changed: 0 additions & 10 deletions
@@ -18,7 +18,6 @@
 import paddle.nn as nn

 from ...utils.masking_utils import _gen_from_sparse_attn_mask_indices
-from .utils import repeat_kv


 def sdpa_attention_forward(
@@ -34,15 +33,6 @@ def sdpa_attention_forward(
     **kwargs,
 ):
     # query: b l h d
-    num_key_value_heads = None
-    if hasattr(module, "num_key_value_heads"):
-        num_key_value_heads = module.num_key_value_heads
-    elif hasattr(module, "num_key_value_groups"):
-        num_key_value_heads = module.num_key_value_groups
-
-    if num_key_value_heads is not None:
-        key = repeat_kv(key, module.num_key_value_heads)
-        value = repeat_kv(value, module.num_key_value_heads)

     if is_causal is None and attn_mask_start_row_indices is None:
         is_causal = query.shape[1] > 1 and attention_mask is None and getattr(module, "is_causal", True)

paddleformers/trainer/trainer.py

Lines changed: 14 additions & 2 deletions
@@ -637,7 +637,9 @@ def _load_from_peft_checkpoint(self, resume_from_checkpoint=None):
         elif isinstance(self.model, LoKrModel):
             weights_file = os.path.join(resume_from_checkpoint, LOKR_WEIGHTS_NAME)
         elif isinstance(self.model, ReFTModel):
-            self.model.from_pretrained(resume_from_checkpoint, self.model.model)
+            self.model.from_pretrained(
+                resume_from_checkpoint, self.model.model, convert_from_hf=self.args.convert_from_hf
+            )
             return

         if self.args.dataset_rank == 0:
@@ -689,6 +691,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
            self.unified_checkpoint_handler.load_unified_checkpoint(
                self.model,
                resume_from_checkpoint,
+               convert_from_hf=self.args.convert_from_hf,
            )
            if isinstance(self.model, LoRAModel) and self.model.lora_config.loraga:
                self.model.reinit_base_model = True
@@ -1452,6 +1455,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
            self.unified_checkpoint_handler.load_unified_checkpoint(
                self.model,
                self.state.best_model_checkpoint,
+               convert_from_hf=self.args.convert_from_hf,
            )
            if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
                broadcast_dataset_rank0_model(self.model)
@@ -1502,6 +1506,7 @@ def _load_best_model_from_peft_checkpoint(self):
            self.unified_checkpoint_handler.load_unified_checkpoint(
                self.model,
                self.state.best_model_checkpoint,
+               convert_from_hf=self.args.convert_from_hf,
            )
            if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1:
                broadcast_dataset_rank0_model(self.model)
@@ -3010,7 +3015,9 @@ def _save(
         # backup and remove unified_checkpoint_config for not trine stage
         if not self.is_in_train:
             self.args.unified_checkpoint_config = []
-        self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir, signal_dir)
+        self.unified_checkpoint_handler.save_unified_checkpoint(
+            self.model, self.optimizer, output_dir, signal_dir, save_to_hf=self.args.save_to_hf
+        )

         # recover unified_checkpoint_config for not trine stage
         if not self.is_in_train:
@@ -3034,6 +3041,7 @@ def _save(
                merge_tensor_parallel=merge_tensor_parallel,
                is_main_process=self.args.should_save,
                max_shard_size="1024GB",
+               save_to_hf=self.args.save_to_hf,
            )
            # TODO: @ZHUI unify unwrap_model(self.model) and self.model
        elif not isinstance(self.model, PretrainedModel):
@@ -3052,6 +3060,7 @@ def _save(
                save_function=self._save_ckpt_func,
                is_main_process=self.args.should_save,
                max_shard_size="1024GB",
+               save_to_hf=self.args.save_to_hf,
            )
        else:
            unwrap_model(self.model).save_pretrained(
@@ -3061,6 +3070,7 @@ def _save(
                save_function=self._save_ckpt_func,
                is_main_process=self.args.should_save,
                max_shard_size="1024GB",
+               save_to_hf=self.args.save_to_hf,
            )
        else:
            logger.info("Trainer.model is not a `PretrainedModel`, only saving its state dict.")
@@ -3093,6 +3103,7 @@ def _save(
                save_function=self._save_ckpt_func,
                is_main_process=self.args.should_save,
                max_shard_size="1024GB",
+               save_to_hf=self.args.save_to_hf,
            )
        else:
            self.model.save_pretrained(
@@ -3102,6 +3113,7 @@ def _save(
                save_function=self._save_ckpt_func,
                is_main_process=self.args.should_save,
                max_shard_size="1024GB",
+               save_to_hf=self.args.save_to_hf,
            )
        if self.args.should_save_sharding_stage1_model:
            model_meta = self.sharding_io.gather_distributed_model_meta()

paddleformers/trainer/training_args.py

Lines changed: 8 additions & 0 deletions
@@ -1080,6 +1080,14 @@ class TrainingArguments:
         default=False,
         metadata={"help": "是否开启单路sharding时global norm通信拆分全局通信组为pp通信和mp通信分别做"},
     )
+    convert_from_hf: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Load model from HuggingFace safetensors."},
+    )
+    save_to_hf: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Save model to HuggingFace safetensors."},
+    )

     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
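
Note: a hedged usage sketch of the two new arguments; only convert_from_hf and save_to_hf come from the diff above, while the import path and the other fields are assumptions:

from paddleformers.trainer import TrainingArguments

# Hypothetical configuration: load initial weights from HuggingFace torch
# safetensors and write checkpoints back in the same format.
args = TrainingArguments(
    output_dir="./checkpoints",
    convert_from_hf=True,  # read HF-format safetensors at load time
    save_to_hf=True,       # write HF-format safetensors at save time
)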

paddleformers/trainer/unified_checkpoint/async_handler.py

Lines changed: 19 additions & 7 deletions
@@ -21,6 +21,7 @@
 import paddle
 import paddle.distributed as dist

+from ...transformers.model_utils import prepare_safe_save_state_dict
 from ...transformers.utils import is_safetensors_available
 from ...utils.log import logger

@@ -70,16 +71,20 @@ def __init__(self, args):
         self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

     def _file_save_async_or_sync(
-        self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight", ckpt_quant_stage="O0"
+        self,
+        state_dict,
+        path,
+        signal_path=None,
+        is_sync=True,
+        state_dict_type="model_weight",
+        ckpt_quant_stage="O0",
+        save_to_hf=False,
     ):
         if is_sync:
-            for k in list(state_dict.keys()):
-                if isinstance(state_dict[k], paddle.Tensor):
-                    state_dict[k] = state_dict.pop(k).cpu().numpy()
-
+            state_dict, metadata = prepare_safe_save_state_dict(state_dict, save_to_hf=save_to_hf)
             if state_dict_type == "optimizer_weight" and ckpt_quant_stage != "O0":
                 state_dict = quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage)
-            safe_save_file(state_dict, path, metadata={"format": "np"})
+            safe_save_file(state_dict, path, metadata=metadata)
         else:
             if len(state_dict.keys()) == 0:
                 saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{self.global_rank}")
@@ -107,6 +112,8 @@ def _file_save_async_or_sync(
                     self._lock,
                     state_dict_type,
                     self.global_rank,
+                    ckpt_quant_stage,
+                    save_to_hf,
                 ),
             )
             self._process_model_weight.start()
@@ -134,6 +141,8 @@ def _file_save_async_or_sync(
                    if "skip_save_model_weight" in self.args.unified_checkpoint_config
                    else state_dict_type,
                    self.global_rank,
+                   ckpt_quant_stage,
+                   save_to_hf,
                ),
            )
            self._process_master_weight.start()
@@ -160,6 +169,7 @@ def _file_save_async_or_sync(
                    state_dict_type,
                    self.global_rank,
                    ckpt_quant_stage,
+                   save_to_hf,
                ),
            )
            self._process_optimizer_weight.start()
@@ -191,6 +201,7 @@ def _save_file_async_in_process(
        state_dict_type,
        global_rank,
        ckpt_quant_stage="O0",
+       save_to_hf=False,
    ):
        shm = shared_memory.SharedMemory(name=shm_name)
        while True:
@@ -208,7 +219,8 @@
                state_dict = quant_unified_optimizer(
                    state_dict, state_dict_type, ckpt_quant_stage, async_save=True
                )  # ckpt quantization
-           safe_save_file(state_dict, path, {"format": "np"})
+           metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
+           safe_save_file(state_dict, path, metadata=metadata)
            del state_dict
            saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
            paddle.save(global_rank, saved_signal_path)
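
Note: in the save path, save_to_hf switches the safetensors header metadata from {"format": "np"} (PaddleFormers-native) to {"format": "pt"}, the convention used by HuggingFace torch checkpoints. A minimal sketch using safetensors' numpy backend directly, not the repo's safe_save_file wrapper:

import numpy as np
from safetensors.numpy import save_file

state_dict = {"linear.weight": np.zeros((4, 4), dtype=np.float32)}
# {"format": "pt"} corresponds to save_to_hf=True; the default path writes {"format": "np"}.
save_file(state_dict, "model.safetensors", metadata={"format": "pt"})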

paddleformers/trainer/unified_checkpoint/load_local.py

Lines changed: 10 additions & 2 deletions
@@ -50,7 +50,9 @@
 __all__ = ["load_unified_checkpoint_locally", "load_unified_optimizer_locally"]


-def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
+def load_unified_checkpoint_locally(
+    args, model, resume_from_checkpoint: str, safe_serialization=False, convert_from_hf=False
+):
     """
     Only dataset_rank == 0 or using expert parallel can enter this function.
     """
@@ -114,8 +116,14 @@ def _remove_unused_keys(
        else:
            tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
        # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
+       transpose_weight_keys = getattr(model, "transpose_weight_keys", None)
        state_dict = load_state_dict(
-           shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+           shard_file,
+           tp_actions if pre_tensor_parallel_split else None,
+           expected_keys,
+           device="expected",
+           convert_from_hf=convert_from_hf,
+           transpose_weight_keys=transpose_weight_keys,
        )

        if not pre_tensor_parallel_split:
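
Note: transpose_weight_keys exists because torch nn.Linear stores weights as [out_features, in_features] while Paddle's nn.Linear stores [in_features, out_features], so selected 2-D weights must be transposed when converting from HF. An illustrative sketch of that idea, not the repo's load_state_dict internals (key matching is simplified):

import numpy as np


def transpose_selected_weights(state_dict, transpose_weight_keys):
    # Transpose only the 2-D weights whose names match the model-provided key list.
    if not transpose_weight_keys:
        return state_dict
    for name, tensor in list(state_dict.items()):
        if tensor.ndim == 2 and any(key in name for key in transpose_weight_keys):
            state_dict[name] = np.ascontiguousarray(tensor.T)
    return state_dict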

paddleformers/trainer/unified_checkpoint/load_save_single_card.py

Lines changed: 24 additions & 10 deletions
@@ -20,7 +20,12 @@
 import paddle

 from ...peft import LoRAModel, PrefixModelForCausalLM
-from ...transformers.model_utils import _load_state_dict_into_model, load_state_dict
+from ...transformers.conversion_utils import ConversionMixin
+from ...transformers.model_utils import (
+    _load_state_dict_into_model,
+    load_state_dict,
+    prepare_safe_save_state_dict,
+)
 from ...transformers.utils import (
     dtype_byte_size,
     get_checkpoint_shard_files,
@@ -54,17 +59,19 @@
 ]


-def save_file_sync(state_dict, path):
-    for k in list(state_dict.keys()):
-        if isinstance(state_dict[k], paddle.Tensor):
-            state_dict[k] = state_dict.pop(k).cpu().numpy()
-    safe_save_file(state_dict, path, metadata={"format": "np"})
+def save_file_sync(state_dict, path, save_to_hf=False):
+    state_dict, metadata = prepare_safe_save_state_dict(state_dict, save_to_hf=save_to_hf)
+    safe_save_file(state_dict, path, metadata=metadata)


-def save_single_card_checkpoint(model_to_save, output_dir):
+def save_single_card_checkpoint(model_to_save, output_dir, save_to_hf=False):
     """Save checkpoint for non-distributed environment."""

     state_dict = get_expected_state_dict(model_to_save, concat_additional_adapter=True)
+    if save_to_hf:
+        transpose_weight_keys = getattr(model_to_save, "transpose_weight_keys", None)
+        state_dict = ConversionMixin.convert_transpose_selected_weights(state_dict, transpose_weight_keys)
+
     if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
         weight_filename = "peft_model-00001-of-00001.safetensors"
         index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME
@@ -92,7 +99,7 @@ def save_single_card_checkpoint(model_to_save, output_dir):

     # save checkpoint, do no support asynchronous save for single card currently.
     logger.warning("Asynchronous saving is not supported for single card environment currently.")
-    save_file_sync(state_dict, path=os.path.join(output_dir, weight_filename))
+    save_file_sync(state_dict, path=os.path.join(output_dir, weight_filename), save_to_hf=save_to_hf)

     save_model_config(model_to_save, output_dir)

@@ -162,7 +169,7 @@ def save_single_card_optimizer(model, optimizer, output_dir):
        save_file_sync(master_weights, path=os.path.join(output_dir, "master_weights-00001-of-00001.safetensors"))


-def load_single_card_checkpoint(model, resume_from_checkpoint: str):
+def load_single_card_checkpoint(model, resume_from_checkpoint: str, convert_from_hf=False):
     if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM):
         index_filename = SAFE_PEFT_WEIGHTS_INDEX_NAME
     else:
@@ -180,7 +187,14 @@ def load_single_card_checkpoint(model, resume_from_checkpoint: str):
     if len(missing_keys) > 0:
         raise ValueError(f"Missing keys: {missing_keys}")

-    state_dict = load_state_dict(resolved_archive_file[0], None, expected_keys)
+    transpose_weight_keys = getattr(model, "transpose_weight_keys", None)
+    state_dict = load_state_dict(
+        resolved_archive_file[0],
+        None,
+        expected_keys,
+        convert_from_hf=convert_from_hf,
+        transpose_weight_keys=transpose_weight_keys,
+    )
     error_msgs = _load_state_dict_into_model(model, state_dict, "")
     del state_dict
     gc.collect()
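
Note: a hypothetical single-card round trip built on the signatures introduced above; model stands in for any PaddleFormers PretrainedModel instance and the directory name is arbitrary:

from paddleformers.trainer.unified_checkpoint.load_save_single_card import (
    load_single_card_checkpoint,
    save_single_card_checkpoint,
)


def hf_round_trip(model, ckpt_dir="./hf_ckpt"):
    # Write the checkpoint as HF torch-style safetensors, then reload it,
    # converting the selected weights back to Paddle layout on the way in.
    save_single_card_checkpoint(model, ckpt_dir, save_to_hf=True)
    load_single_card_checkpoint(model, ckpt_dir, convert_from_hf=True)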
