Commit 08898bf

Cherry-Pick fast_safe_open (#8458)
* [Performance] Optimize unified checkpoint save/load speed. (#8204)
* opt unified checkpoint save/load speed.
1 parent fc860a3 commit 08898bf

File tree

10 files changed: +490 -41 lines changed


paddlenlp/trainer/plugins/unified_checkpoint.py

Lines changed: 29 additions & 22 deletions
@@ -30,6 +30,7 @@
 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _load_state_dict_into_model,
+    faster_set_state_dict,
     get_parameter_dtype,
     load_state_dict,
     unwrap_model,
@@ -65,9 +66,10 @@
 from paddlenlp.utils.nested import nested_copy, nested_copy_place

 if is_safetensors_available():
-    from safetensors import safe_open
+    # from safetensors import safe_open
     from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open

 FP32_MASTER = "fp32_master_0"
 optimizer_scalar_name = [
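
This swap keeps every downstream call site unchanged: safe_open now resolves to PaddleNLP's own fast_safe_open instead of the upstream safetensors.safe_open. A minimal sketch of how the drop-in is consumed, assuming fast_safe_open mirrors the safe_open context-manager interface (keys(), get_slice()); the helper function below is illustrative, not part of the commit:

# Sketch only: assumes fast_safe_open mirrors safetensors.safe_open.
from paddlenlp.utils.safetensors import fast_safe_open as safe_open

def read_one_tensor(checkpoint_file, name):
    # framework="np" keeps the data as numpy and avoids an extra device copy here
    with safe_open(checkpoint_file, framework="np") as f:
        if name not in f.keys():
            raise KeyError(f"{name} not found in {checkpoint_file}")
        # slicing the lazy slice object reads only the requested bytes from disk
        return f.get_slice(name)[:]
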
@@ -91,6 +93,11 @@
 async_save_queue = []


+DEST_PLACE = paddle.CPUPlace()
+if paddle.device.is_compiled_with_cuda():
+    DEST_PLACE = paddle.CUDAPinnedPlace()
+
+
 class UnifiedCheckpointOption(ExplicitEnum):
     """
     "- skip_save_model_weight: do not save model weights when the masters weight exist\n"
@@ -196,7 +203,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
     Returns:
         None
     """
-
     if paddle.distributed.get_world_size() <= 1:
         load_single_card_checkpoint(args, model, resume_from_checkpoint)
         return
@@ -222,7 +228,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
         pretrained_model_name_or_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
     )
-
     loaded_keys = sharded_metadata["all_checkpoint_keys"]

     model_state_dict = get_expected_state_dict(model)
@@ -266,7 +271,9 @@ def _remove_unused_keys(
             else:
                 tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
         # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-        state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
+        state_dict = load_state_dict(
+            shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+        )

         if not pre_tensor_parallel_split:
             # Since we load all keys but we only need one of pipeline stages
@@ -279,11 +286,12 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )

-        error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        # error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)

         # force memory release
         del state_dict
-        gc.collect()
+        # gc.collect()

     if len(error_msgs) > 0:
         error_msg = "\n\t".join(error_msgs)
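
Taken together, these hunks replace the old "read to CPU, then copy tensor by tensor into the model" path with "read onto the expected device, then share the loaded storage into the existing parameters", and they drop the per-shard gc.collect(). A simplified, hedged sketch of the resulting load loop, reusing the names from this diff (error handling trimmed):

# Hedged sketch of the new load path; simplified from the code above.
error_msgs = []
for shard_file in resolved_archive_file:
    # device="expected" materializes weights on the currently expected device
    state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
    # faster_set_state_dict shares storage into the model's parameters instead of
    # copying; strict_dtype=False permits float-to-float dtype casts.
    error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)
    del state_dict  # drop references without forcing a full gc.collect()
if error_msgs:
    raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(error_msgs))
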
@@ -337,6 +345,7 @@ def unified_checkpoint_into_shards(
         tp_actions = model_to_save.get_tensor_parallel_convert_actions(
             model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
         )
+        logger.info("Unified model tensor parallel weights in shards")
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)

     # build index json file
@@ -490,6 +499,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     # This should always be a list but, just to be sure.
     if not isinstance(resolved_archive_file, list):
         resolved_archive_file = [resolved_archive_file]
+
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")

@@ -537,10 +547,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                     tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)

                 # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
             else:
                 # for pipeline model, we don't need to use tp_actions
-                state_dict = load_state_dict(shard_file, None, expected_keys)
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")

             returned_state_dict.update(state_dict)
             # force memory release
@@ -553,7 +563,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         state_dict_master_weight = load_resolved_archive_file(
             resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
         )
-
     # rename optimizer param
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
@@ -562,13 +571,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
-        returned_optim_state_dict[key_name] = state_dict_optim[key]
+        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name

     if has_master_weights:
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings[key]
-            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
+            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

     returned_optim_state_dict = nested_copy_place(
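
Renaming with .pop() instead of plain indexing releases each tensor's reference in the source dict as soon as it has been transferred, so peak memory stays close to one copy of the optimizer state rather than two. The same pattern in a minimal, self-contained form (the key format here is illustrative only):

# Minimal illustration of the pop-while-renaming pattern used above.
src = {"linear_0.w_0/moment1_0": [0.1], "linear_0.b_0/moment1_0": [0.2]}
renamed = {}
for key in list(src.keys()):  # list() lets us mutate src while iterating
    struct_name, suffix = key.split("/")
    renamed["_".join([struct_name, suffix])] = src.pop(key)  # pop drops src's reference
assert not src  # the source dict empties as we go
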
@@ -640,6 +649,7 @@ def unified_optimizer_into_shards(
         tp_actions = model.get_tensor_parallel_convert_actions(
             model.config, model_keys, is_split=False, ignore_error=True
         )
+        logger.info("Unified optimizer tensor parallel in shards")
         optim_state_dict = merge_tensor_parallel_for_optimizer(
             optim_state_dict,
             tp_actions,
@@ -648,6 +658,7 @@
         paddle.device.cuda.empty_cache()

         if master_weights is not None:
+            logger.info("Unified master weight tensor parallel in shards")
             master_weights = merge_tensor_parallel_for_optimizer(
                 master_weights,
                 tp_actions,
@@ -703,7 +714,6 @@ def unified_optimizer_into_shards(
 def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
     index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
     index_filename = os.path.join(resume_from_checkpoint, index_filename)
-
     # Find index json file and distribute this file in global group.
     if distributed_isfile(index_filename):
         distributed_file(index_filename)
@@ -1605,7 +1615,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
     tp_group = hcg.get_model_parallel_group()
     pp_group = hcg.get_pipe_parallel_group()

-    logger.info("Unified checkpoint generating sharded_index json files.")
+    logger.info(
+        f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
+    )

     if tp_group.nranks > 1:
         dist.all_gather_object(index_file_list, index_file, tp_group)
@@ -1714,8 +1726,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):


 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified checkpoint merge tensor parallel in shards")
-
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1741,7 +1751,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

             if is_dst:
                 state_dict_to_save[key] = tensor
@@ -1754,8 +1764,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):


 def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified optimizer tensor parallel in shards")
-
+    # Core function for UC
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1773,15 +1782,13 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
             if model_key in tp_actions:
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
-                    tensor = (
-                        tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
-                    )  # Need broadcast when loaded
+                    tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(DEST_PLACE, False) if is_dst else None

             if is_dst:
                 state_dict_to_save[filter_keys[i]] = tensor
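
The _copy_to(DEST_PLACE, False) calls above stage gathered shards into CUDA pinned (page-locked) host memory on CUDA builds, which typically transfers to and from the device faster than pageable CPU memory, and fall back to plain CPU memory otherwise. The same selection pattern in isolation, as a hedged sketch (the sample tensor is ours, not from the commit):

import paddle

# Prefer pinned host memory when the wheel is built with CUDA, else plain CPU memory.
dest_place = paddle.CPUPlace()
if paddle.device.is_compiled_with_cuda():
    dest_place = paddle.CUDAPinnedPlace()

# Illustrative use: stage a tensor into the chosen place, mirroring the
# tensor._copy_to(DEST_PLACE, False) calls in the merge helpers above.
x = paddle.ones([4, 4])
x_host = x._copy_to(dest_place, False)
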

paddlenlp/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -2419,6 +2419,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             self.runtime_timer.stop()
             return

+        logger.info("Loading optimizer and scheduler...")
         if (not self.args.should_load_sharding_stage1_model) and self.args.ignore_load_lr_and_optim:
             self.runtime_timer.stop()
             return

paddlenlp/transformers/conversion_utils.py

Lines changed: 13 additions & 3 deletions
@@ -285,8 +285,12 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):

     if isinstance(weight_list[0], np.ndarray):
         return np.concatenate([reorder[i] for i in index], axis=axis)
+    else:
+        tensor = paddle.concat([reorder[i] for i in index], axis=axis)

-    return paddle.concat([reorder[i] for i in index], axis=axis)._copy_to(paddle.CPUPlace(), False)
+    if tensor.place.is_gpu_place():
+        tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+    return tensor


 def naive_fuse_split_tp(
@@ -361,12 +365,18 @@ def normal_fuse_merge_tp(weight_list, is_column=True):
         if isinstance(weight_list[0], np.ndarray):
             return np.concatenate(weight_list, axis=-1)
         else:
-            return paddle.concat(weight_list, axis=-1)._copy_to(paddle.CPUPlace(), False)
+            tensor = paddle.concat(weight_list, axis=-1)
+            if tensor.place.is_gpu_place():
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+            return tensor
     else:
         if isinstance(weight_list[0], np.ndarray):
             return np.concatenate(weight_list, axis=0)
         else:
-            return paddle.concat(weight_list, axis=0)._copy_to(paddle.CPUPlace(), False)
+            tensor = paddle.concat(weight_list, axis=0)
+            if tensor.place.is_gpu_place():
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False)
+            return tensor


 def normal_fuse_split_tp(weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True):
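
Both merge helpers now leave the concatenated tensor where it is unless it sits on the GPU, in which case it is staged into pinned host memory rather than pageable CPU memory. A hedged sketch of that guard factored into a standalone helper (the helper name is ours, not part of this commit):

import paddle

def offload_if_on_gpu(tensor):
    # Illustrative helper: GPU tensors are staged to pinned host memory so a
    # later device copy stays cheap; CPU or already-pinned tensors pass through.
    if tensor.place.is_gpu_place():
        return tensor._copy_to(paddle.CUDAPinnedPlace(), False)
    return tensor

merged = offload_if_on_gpu(paddle.concat([paddle.ones([2, 2]), paddle.ones([2, 2])], axis=0))
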

paddlenlp/transformers/model_utils.py

Lines changed: 35 additions & 16 deletions
@@ -109,10 +109,13 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):

 if is_safetensors_available():

-    from safetensors import safe_open
-    from safetensors.numpy import load_file as safe_load_file
+    # from safetensors import safe_open
+    # from safetensors.numpy import load_file as safe_load_file
     from safetensors.numpy import save_file as safe_save_file

+    from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open
+

 def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
     """
@@ -313,7 +316,7 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:


 def load_state_dict(
-    checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None
+    checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping=None, fliter_dict_keys=None, device="cpu"
 ):
     """
     Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
@@ -346,11 +349,16 @@
                     weight = tensor_parallel_split_mapping[key](py_safe_slice_)
                 else:
                     weight = py_safe_slice_[:]
+                if device == "expected":
+                    with device_guard():
+                        weight = paddle.Tensor(weight, zero_copy=True)
+                    weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                 state_dict[key] = weight

-        for k in list(state_dict.keys()):
-            with device_guard():
-                state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)
+        if device == "cpu":
+            for k in list(state_dict.keys()):
+                with device_guard():
+                    state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

         return state_dict
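
The new device argument controls where the safetensors slices end up: "cpu" keeps the previous behaviour (zero-copy CPU tensors), while "expected" wraps each slice as a CPU tensor and immediately copies it to the framework's current expected place, typically the GPU that will consume it. A sketch of calling it both ways, assuming the signature shown above (the shard file name is a placeholder):

# Hedged sketch; the shard path is illustrative only.
cpu_state = load_state_dict("model-00001-of-00008.safetensors", device="cpu")
gpu_state = load_state_dict(
    "model-00001-of-00008.safetensors",
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="expected",  # tensors land on paddle.framework._current_expected_place()
)
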

@@ -672,8 +680,10 @@ def load_sharded_checkpoint(model, folder, variant=None, strict=True, prefer_saf
     return missing_keys, unexpected_keys


-def faster_set_state_dict(model, state_dict):
+def faster_set_state_dict(model, state_dict, strict_dtype=True):
     # the state_dict will be destroied.
+    unused_keys = set(state_dict.keys())
+    unset_keys = set(model.state_dict().keys())
     with paddle.no_grad():
         for k, v in model.state_dict().items():
             if k in state_dict:
@@ -683,8 +693,10 @@ def faster_set_state_dict(model, state_dict):
                         f"faster_set_state_dict need state dict with paddle.Tensor, but got {type(v_new)}"
                     )
                 # 2. cast param / Tensor to dtype
+                #
                 if v.dtype != v_new.dtype:
-                    raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
+                    if strict_dtype or (not v.is_floating_point() or not v_new.is_floating_point()):
+                        raise ValueError(f"for key: {k}, expect dtype {v.dtype}, but got {v_new.dtype}")
                 # check shape
                 if list(v.shape) != list(v_new.shape):
                     raise ValueError(f"for key: {k}, expect shape {v.shape}, but got {v_new.shape}")
@@ -700,9 +712,22 @@ def faster_set_state_dict(model, state_dict):
                 else:
                     new_t = v_new

+                if not strict_dtype and v.dtype != new_t.dtype:
+                    new_t = new_t.astype(v.dtype)
+
                 # 4. share Tensor to origin param / Tensor
                 src_tensor = new_t.value().get_tensor()
                 dst_tensor._share_data_with(src_tensor)
+                unset_keys.remove(k)
+                unused_keys.remove(k)
+
+    error_msgs = []
+    # if len(unset_keys) > 0:
+    #     error_msgs.append(f"Those weight of model is not initialized: {list(unset_keys)}")
+    if len(unused_keys) > 0:
+        error_msgs.append(f"Those state dict keys are not using in model: {list(unused_keys)}")
+
+    return error_msgs


 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
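
With strict_dtype=False, a floating-point dtype mismatch is resolved by casting the incoming tensor to the parameter's dtype before the storages are shared; mismatches involving non-float tensors still raise. The function now also returns error_msgs listing checkpoint keys the model never consumed. A hedged usage sketch, assuming model and state_dict come from the loading code above:

# Sketch: apply a loaded state dict in place and surface unused keys.
error_msgs = faster_set_state_dict(model, state_dict, strict_dtype=False)
# The loaded tensors now back the model's parameters via shared storage, so
# state_dict should be treated as consumed after this call.
for msg in error_msgs:
    logger.warning(msg)  # e.g. keys present in the checkpoint but absent from the model
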
@@ -734,22 +759,16 @@ def _convert_state_dict_dtype_and_shape(state_dict, model_to_load):
     def is_0d_or_1d(tensor):
         return len(tensor.shape) == 0 or list(tensor.shape) == [1]

-    expected_place = paddle.framework._current_expected_place()
     for key, value in model_to_load.state_dict().items():
-        if key in state_dict:
+        if key in list(state_dict.keys()):
             if isinstance(state_dict[key], np.ndarray):
                 raise ValueError(
                     "convert_state_dict_dtype expected paddle.Tensor not numpy.ndarray, plase convert numpy.ndarray to paddle.Tensor"
                 )
             # confirm parameter cast is executed on the same device as model
             # TODO: cast(FP32 -> FP16) has diff on different devices, need to fix it
             if state_dict[key].is_floating_point() and state_dict[key].dtype != value.dtype:
-                value_pop = state_dict.pop(key)
-                value_new_place = (
-                    value_pop if value_pop.place == expected_place else value_pop._copy_to(expected_place, False)
-                )
-                state_dict[key] = paddle.cast(value_new_place, value.dtype)._copy_to(value_pop.place, False)
-                del value_new_place
+                state_dict[key] = paddle.cast(state_dict.pop(key), value.dtype)
             # unified 0d and 1d tensor
             if is_0d_or_1d(value) and is_0d_or_1d(state_dict[key]):
                 if list(value.shape) != list(state_dict[key].shape):
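
The simplified branch casts the tensor on whatever device it already occupies, so the old round trip through the expected place (copy, cast, copy back) is gone. A minimal, hedged sketch of the pop-cast-reassign step in isolation (tensor names are illustrative):

import paddle

# Sketch of the simplified dtype alignment above.
state_dict = {"linear.weight": paddle.ones([2, 2], dtype="float32")}
target_dtype = paddle.float16
if state_dict["linear.weight"].is_floating_point() and state_dict["linear.weight"].dtype != target_dtype:
    state_dict["linear.weight"] = paddle.cast(state_dict.pop("linear.weight"), target_dtype)
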
