Commit 9924a12

Merge branch 'PaddlePaddle:develop' into feat/model-unittest-ci-action
2 parents 4a99803 + 89ec6b3

4 files changed (+95, -75 lines)


paddleformers/trainer/unified_checkpoint/async_handler.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from ...utils.log import logger

 if is_safetensors_available():
-    from safetensors.paddle import save_file as safe_save_file
+    from safetensors.numpy import save_file as safe_save_file

 from ...quantization.unified_checkpoint_quantization import quant_unified_optimizer
 from .shared_memory_utils import (
@@ -219,7 +219,7 @@ def _save_file_async_in_process(
     state_dict = quant_unified_optimizer(
         state_dict, state_dict_type, ckpt_quant_stage, async_save=True
     )  # ckpt quantization
-    metadata = {"format": "pt"} if save_to_hf else {"format": "paddle"}
+    metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
     safe_save_file(state_dict, path, metadata=metadata)
     del state_dict
     saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
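Note: the serialization backend moves from safetensors.paddle to safetensors.numpy, so non-HF checkpoints are now tagged {"format": "np"} rather than {"format": "paddle"}. A minimal sketch of the new save path (tensor name, shape, and file name below are illustrative):

import numpy as np
from safetensors.numpy import save_file as safe_save_file

# The numpy backend expects a dict of numpy arrays, not paddle tensors.
state_dict = {"linear.weight": np.zeros((4, 4), dtype=np.float32)}

save_to_hf = False
metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
safe_save_file(state_dict, "model.safetensors", metadata=metadata)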

paddleformers/trainer/unified_checkpoint/load_save_single_card.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@
 from ...utils.nested import nested_copy

 if is_safetensors_available():
-    from safetensors.paddle import save_file as safe_save_file
+    from safetensors.numpy import save_file as safe_save_file

 from .utils import (
     FP32_MASTER,

paddleformers/transformers/model_utils.py

Lines changed: 91 additions & 71 deletions
@@ -21,6 +21,7 @@
 import json
 import os
 import re
+import sys
 import tempfile
 import warnings
 from contextlib import contextmanager
@@ -29,6 +30,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

 import aistudio_sdk
+import ml_dtypes
 import numpy as np
 import paddle
 import paddle.nn as nn
@@ -126,9 +128,14 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):


 if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.paddle import load_file as safe_load_file
-    from safetensors.paddle import save_file as safe_save_file
+    from safetensors.numpy import save_file as safe_save_file
+
+    from ..utils.safetensors import fast_load_file as safe_load_file
+
+    if sys.platform.startswith("win"):
+        from safetensors import safe_open
+    else:
+        from ..utils.safetensors import fast_safe_open as safe_open


 def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
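The import block now prefers the repo's own readers (fast_load_file and fast_safe_open in paddleformers/utils/safetensors.py), keeping upstream safetensors.safe_open only as the Windows fallback. A sketch of the selection pattern, assuming the fast reader is a drop-in replacement for safe_open (file name illustrative):

import sys

if sys.platform.startswith("win"):
    from safetensors import safe_open  # upstream reader on Windows
else:
    from paddleformers.utils.safetensors import fast_safe_open as safe_open

# Either way, the handle exposes the same context-manager interface.
with safe_open("model.safetensors", framework="np") as f:
    for key in f.keys():
        array = f.get_tensor(key)  # a numpy array under framework="np"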
@@ -395,7 +402,7 @@ def _transpose_hf_weight(key, weight):

     part_state_dict = {}
     scale_dict = {}
-    with safe_open(checkpoint_file, framework="paddle") as f:
+    with safe_open(checkpoint_file, framework="np") as f:
         for key in keys:
             # 1. non-merge ckpt loading dont have filter key.
             # 2. merge ckpt will skip quant scale by `fliter_dict_keys`
@@ -415,7 +422,8 @@ def _transpose_hf_weight(key, weight):
                 and key.split(".weight")[0] in quantization_linear_list
                 and not key.endswith("_scale")
             ):
-                weight = py_safe_slice_[:]
+                # numpy.array -> paddle.tensor
+                weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True)
                 weight = _transpose_hf_weight(key, weight)
                 key_name = key.split(".weight")[0]
                 quant_key_name = key_name + ".quant_weight"
@@ -450,17 +458,19 @@ def _transpose_hf_weight(key, weight):
                         is_column = not is_column
                     tp_fn = partial(tp_fn.func, *tp_fn.args, **{**tp_fn.keywords, "is_column": is_column})
                 if len(py_safe_slice_.shape) == 0:
-                    weight = tp_fn(py_safe_slice_[:])
+                    weight = tp_fn(py_safe_slice_.get())
                 else:
                     weight = tp_fn(py_safe_slice_)
             else:
-                weight = py_safe_slice_[:]
-
+                if len(py_safe_slice_.shape) == 0:
+                    weight = py_safe_slice_.get()
+                else:
+                    weight = py_safe_slice_[:]
             if not return_numpy and device == "expected":
+                with device_guard():
+                    weight = paddle.Tensor.__call__(weight, zero_copy=True)
                 weight = weight._copy_to(paddle.framework._current_expected_place(), False)
             weight = _transpose_hf_weight(key, weight)
-            if return_numpy:
-                weight = weight.numpy()
             part_state_dict[key] = weight

     for key in keys:
@@ -471,9 +481,9 @@ def _transpose_hf_weight(key, weight):
             ):
                 scale = f.get_tensor(key)
                 if not return_numpy and device == "expected":
+                    with device_guard():
+                        scale = paddle.Tensor.__call__(scale, zero_copy=True)
                     scale = scale._copy_to(paddle.framework._current_expected_place(), False)
-                if return_numpy:
-                    scale = scale.numpy()
                 scale_dict[key] = scale
     return part_state_dict, scale_dict
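With framework="np", slices come back as numpy arrays, so the loader rebuilds paddle tensors explicitly: paddle.Tensor.__call__(array, zero_copy=True) wraps the numpy buffer without copying, and device_guard() keeps construction on CPU before any move to the expected place. A rough sketch of that conversion, with a simplified stand-in for the repo's device_guard helper:

from contextlib import contextmanager

import numpy as np
import paddle


@contextmanager
def device_guard(device="cpu"):
    # Simplified stand-in: pin tensor construction to the given device, then restore.
    origin = paddle.device.get_device()
    paddle.device.set_device(device)
    try:
        yield
    finally:
        paddle.device.set_device(origin)


weight = np.ones((2, 2), dtype="float32")  # stands in for a safetensors slice
with device_guard():
    tensor = paddle.Tensor.__call__(weight, zero_copy=True)  # shares the numpy buffer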

@@ -501,34 +511,26 @@ def load_state_dict(
     if (
         checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file)
     ) and is_safetensors_available():
-        thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
-        if thread_num > 1:
-            logger.info(f"Set loading state_dict thread num to {thread_num}")
-        state_dict, scale_dict = {}, {}
-        if thread_num <= 1:
-            with safe_open(checkpoint_file, framework="paddle") as f:
-                state_dict, scale_dict = _load_part_state_dict(
-                    list(f.keys()),
-                    checkpoint_file,
-                    tensor_parallel_split_mapping,
-                    fliter_dict_keys,
-                    device,
-                    quantization_linear_list,
-                    quantization_config,
-                    dtype,
-                    return_numpy,
-                    convert_from_hf,
-                    transpose_weight_keys,
-                )
-        else:
-            # Load state dict in multi-thread to speed up loading
-            with safe_open(checkpoint_file, framework="paddle") as f:
-                keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
-            with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
-                future_to_key = {
-                    executor.submit(
-                        _load_part_state_dict,
-                        keys,
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="np") as f:
+            metadata = {"format": "np"}
+
+        if metadata.get("format", "np") not in ["pd", "np"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_pretrained` method."
+            )
+        if metadata.get("format", "np") == "pd":
+            raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
+        if metadata.get("format", "np") == "np":
+            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
+            if thread_num > 1:
+                logger.info(f"Set loading state_dict thread num to {thread_num}")
+            state_dict, scale_dict = {}, {}
+            if thread_num <= 1:
+                with safe_open(checkpoint_file, framework="np") as f:
+                    state_dict, scale_dict = _load_part_state_dict(
+                        list(f.keys()),
                         checkpoint_file,
                         tensor_parallel_split_mapping,
                         fliter_dict_keys,
@@ -539,41 +541,54 @@ def load_state_dict(
                         return_numpy,
                         convert_from_hf,
                         transpose_weight_keys,
-                    ): keys
-                    for keys in keys_groups
-                }
-                for future in concurrent.futures.as_completed(future_to_key):
-                    res_state_dict, res_scale_dict = future.result()
-                    state_dict.update(res_state_dict)
-                    scale_dict.update(res_scale_dict)
-
-        if not return_numpy:
-            if device == "pin_memory":
-                for k in list(state_dict.keys()):
-                    pd_tensor = state_dict.pop(k)
-                    state_dict[k] = (
-                        pd_tensor
-                        if pd_tensor.place == paddle.CUDAPinnedPlace()
-                        else pd_tensor.to(paddle.CUDAPinnedPlace())
                     )
-            else:
-                for k in list(state_dict.keys()):
-                    state_dict[k] = state_dict.pop(k).numpy()
+            else:
+                # Load state dict in multi-thread to speed up loading
+                with safe_open(checkpoint_file, framework="np") as f:
+                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+                with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                    future_to_key = {
+                        executor.submit(
+                            _load_part_state_dict,
+                            keys,
+                            checkpoint_file,
+                            tensor_parallel_split_mapping,
+                            fliter_dict_keys,
+                            device,
+                            quantization_linear_list,
+                            quantization_config,
+                            dtype,
+                            return_numpy,
+                            convert_from_hf,
+                            transpose_weight_keys,
+                        ): keys
+                        for keys in keys_groups
+                    }
+                    for future in concurrent.futures.as_completed(future_to_key):
+                        res_state_dict, res_scale_dict = future.result()
+                        state_dict.update(res_state_dict)
+                        scale_dict.update(res_scale_dict)
+
+            if not return_numpy:
+                if device == "cpu":
+                    with device_guard():
+                        for k in list(state_dict.keys()):
+                            state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
+                elif device == "pin_memory":
+                    for k in list(state_dict.keys()):
+                        state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace())

-        if len(scale_dict) != 0:
-            if ckpt_quant_stage == "O0":
-                raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
-            state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)
+            if len(scale_dict) != 0:
+                if ckpt_quant_stage == "O0":
+                    raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
+                state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)

-        return state_dict
+            return state_dict

     # load from hf but not safetensors checkpoint
     if convert_from_hf:
         state_dict = load_torch(checkpoint_file)
         state_dict = ConversionMixin.convert_transpose_selected_weights(state_dict, transpose_weight_keys)
-        if return_numpy:
-            for k in list(state_dict.keys()):
-                state_dict[k] = state_dict.pop(k).numpy()
         return state_dict

     state_dict = paddleformers_load(checkpoint_file, map_location="cpu")
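Two behaviors of the rewritten branch are worth noting: archives tagged "pd" are rejected outright, and when LOAD_STATE_DICT_THREAD_NUM is set above 1 the keys are split evenly across a thread pool so each worker reads its own slice of the file. A self-contained sketch of that multi-threaded pattern (the file name and the even-split helper are illustrative stand-ins for the repo's _split_keys_evenly):

import concurrent.futures
import os

from safetensors import safe_open


def split_keys_evenly(keys, n):
    # Partition keys into n nearly equal contiguous groups.
    size, rem = divmod(len(keys), n)
    groups, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < rem else 0)
        groups.append(keys[start:end])
        start = end
    return groups


def load_keys(path, keys):
    with safe_open(path, framework="np") as f:
        return {k: f.get_tensor(k) for k in keys}


path = "model.safetensors"
thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "4"))
with safe_open(path, framework="np") as f:
    key_groups = split_keys_evenly(list(f.keys()), thread_num)

state_dict = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
    futures = [executor.submit(load_keys, path, g) for g in key_groups]
    for future in concurrent.futures.as_completed(futures):
        state_dict.update(future.result())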
@@ -584,8 +599,10 @@ def prepare_safe_save_state_dict(state_dict, save_to_hf=False):
     for k in list(state_dict.keys()):
         if isinstance(state_dict[k], paddle.Tensor):
             if state_dict[k].dtype == paddle.bfloat16:
-                state_dict[k] = state_dict.pop(k).contiguous().astype(paddle.bfloat16)
-    metadata = {"format": "pt"} if save_to_hf else {"format": "paddle"}
+                state_dict[k] = state_dict.pop(k).astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
+            else:
+                state_dict[k] = state_dict.pop(k).cpu().numpy()
+    metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
     return state_dict, metadata
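Because numpy has no native bfloat16, bf16 paddle tensors are widened to float32 on CPU and then narrowed to ml_dtypes.bfloat16, which safetensors can store at two bytes per element. A small sketch of the round trip (values illustrative):

import ml_dtypes
import numpy as np
import paddle

t = paddle.ones([2, 2], dtype=paddle.bfloat16)
# float32 is a dtype numpy understands; ml_dtypes then supplies bfloat16.
arr = t.astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
assert arr.dtype == ml_dtypes.bfloat16

# Reading back: widen to float32 first, then rebuild the paddle bf16 tensor.
restored = paddle.to_tensor(arr.astype(np.float32)).astype(paddle.bfloat16)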

@@ -2034,6 +2051,7 @@ def get_file_path(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, v
                 f"Error no files {filenames} found in repo {pretrained_model_name_or_path}."
             )
     elif "pytorch_model.bin" in str(resolved_archive_file):
+
         if download_hub == DownloadSource.AISTUDIO and not convert_from_hf:
             raise ValueError(
                 f"Download pytorch weight in "
@@ -2614,7 +2632,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         logger.warning("`load_state_as_np` is deprecated, please delete it!")

         model_kwargs = kwargs
+
         if convert_from_hf is None and download_hub == DownloadSource.MODELSCOPE:
+
             logger.warning(
                 "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
                 " you can set ·convert_from_hf=False·. By default, `convert_from_hf` is set to `True`. "
@@ -2687,7 +2707,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
             state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
         elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model.safetensors"):
-            with safe_open(resolved_archive_file, framework="paddle", device="cpu") as f:
+            with safe_open(resolved_archive_file, framework="np", device="cpu") as f:
                 loaded_keys = f.keys()
                 tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
                 state_dict = load_state_dict(
@@ -3332,7 +3352,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False, convert_from_hf=
     elif os.path.exists(model_path):
         state_dict = cls.convert_tensor_parallel(model_path, config)
     elif os.path.exists(safe_model_path):
-        with safe_open(safe_model_path, framework="paddle", device="cpu") as f:
+        with safe_open(safe_model_path, framework="np", device="cpu") as f:
             loaded_keys = f.keys()
             tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
             state_dict = load_state_dict(

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ sentencepiece
 huggingface_hub>=0.19.2
 protobuf>=3.20.2
 visualdl
-safetensors @ https://paddle-whl.bj.bcebos.com/nightly/cu126/safetensors/safetensors-0.6.2.dev0-cp38-abi3-linux_x86_64.whl
+safetensors
 fast_dataindex>=0.1.1 ; platform_system == "Linux"
 aistudio-sdk>=0.3.0
 jinja2
