
Commit e54947d

Support safetensors.paddle without converting to numpy (#2538)

Co-authored-by: llbdyiu66 <[email protected]>

1 parent ebfbac6

File tree: 4 files changed, +75 −95 lines

paddleformers/trainer/unified_checkpoint/async_handler.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -26,7 +26,7 @@
 from ...utils.log import logger
 
 if is_safetensors_available():
-    from safetensors.numpy import save_file as safe_save_file
+    from safetensors.paddle import save_file as safe_save_file
 
 from ...quantization.unified_checkpoint_quantization import quant_unified_optimizer
 from .shared_memory_utils import (
@@ -219,7 +219,7 @@ def _save_file_async_in_process(
     state_dict = quant_unified_optimizer(
         state_dict, state_dict_type, ckpt_quant_stage, async_save=True
     )  # ckpt quantization
-    metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
+    metadata = {"format": "pt"} if save_to_hf else {"format": "paddle"}
     safe_save_file(state_dict, path, metadata=metadata)
     del state_dict
     saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
```
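For context, a minimal sketch of the new save path (file name and state dict are made up for illustration; the `safetensors.paddle` binding comes from the wheel pinned in requirements.txt below):

```python
import paddle
from safetensors.paddle import save_file as safe_save_file

# Paddle tensors are written directly; previously each tensor had to be
# converted to a numpy array so safetensors.numpy.save_file could serialize it.
state_dict = {"linear.weight": paddle.randn([8, 8])}

# The metadata format tag is now "paddle" instead of "np" when not
# exporting for Hugging Face ("pt").
safe_save_file(state_dict, "model.safetensors", metadata={"format": "paddle"})
```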

paddleformers/trainer/unified_checkpoint/load_save_single_card.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -41,7 +41,7 @@
 from ...utils.nested import nested_copy
 
 if is_safetensors_available():
-    from safetensors.numpy import save_file as safe_save_file
+    from safetensors.paddle import save_file as safe_save_file
 
 from .utils import (
     FP32_MASTER,
```

paddleformers/transformers/model_utils.py

Lines changed: 71 additions & 91 deletions

```diff
@@ -21,7 +21,6 @@
 import json
 import os
 import re
-import sys
 import tempfile
 import warnings
 from contextlib import contextmanager
@@ -30,7 +29,6 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import aistudio_sdk
-import ml_dtypes
 import numpy as np
 import paddle
 import paddle.nn as nn
@@ -128,14 +126,9 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):
 
 
 if is_safetensors_available():
-    from safetensors.numpy import save_file as safe_save_file
-
-    from ..utils.safetensors import fast_load_file as safe_load_file
-
-    if sys.platform.startswith("win"):
-        from safetensors import safe_open
-    else:
-        from ..utils.safetensors import fast_safe_open as safe_open
+    from safetensors import safe_open
+    from safetensors.paddle import load_file as safe_load_file
+    from safetensors.paddle import save_file as safe_save_file
 
 
 def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
```
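A short sketch of the simplified import surface, assuming the nightly safetensors wheel pinned in requirements.txt: `load_file` returns `paddle.Tensor` objects directly, and the upstream `safe_open` replaces the platform-dependent `fast_safe_open` fallback:

```python
from safetensors import safe_open
from safetensors.paddle import load_file

# Hypothetical checkpoint path, for illustration only.
state_dict = load_file("model.safetensors")  # dict[str, paddle.Tensor]

# safe_open accepts framework="paddle" with this build, so lazy key
# listing and slicing no longer go through numpy.
with safe_open("model.safetensors", framework="paddle") as f:
    keys = list(f.keys())
```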
```diff
@@ -402,7 +395,7 @@ def _transpose_hf_weight(key, weight):
 
     part_state_dict = {}
     scale_dict = {}
-    with safe_open(checkpoint_file, framework="np") as f:
+    with safe_open(checkpoint_file, framework="paddle") as f:
         for key in keys:
             # 1. non-merge ckpt loading dont have filter key.
             # 2. merge ckpt will skip quant scale by `fliter_dict_keys`
@@ -422,8 +415,7 @@ def _transpose_hf_weight(key, weight):
                 and key.split(".weight")[0] in quantization_linear_list
                 and not key.endswith("_scale")
             ):
-                # numpy.array -> paddle.tensor
-                weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True)
+                weight = py_safe_slice_[:]
                 weight = _transpose_hf_weight(key, weight)
                 key_name = key.split(".weight")[0]
                 quant_key_name = key_name + ".quant_weight"
@@ -458,19 +450,17 @@ def _transpose_hf_weight(key, weight):
                         is_column = not is_column
                         tp_fn = partial(tp_fn.func, *tp_fn.args, **{**tp_fn.keywords, "is_column": is_column})
                     if len(py_safe_slice_.shape) == 0:
-                        weight = tp_fn(py_safe_slice_.get())
+                        weight = tp_fn(py_safe_slice_[:])
                     else:
                         weight = tp_fn(py_safe_slice_)
                 else:
-                    if len(py_safe_slice_.shape) == 0:
-                        weight = py_safe_slice_.get()
-                    else:
-                        weight = py_safe_slice_[:]
+                    weight = py_safe_slice_[:]
+
                 if not return_numpy and device == "expected":
-                    with device_guard():
-                        weight = paddle.Tensor.__call__(weight, zero_copy=True)
                     weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                 weight = _transpose_hf_weight(key, weight)
+                if return_numpy:
+                    weight = weight.numpy()
                 part_state_dict[key] = weight
 
         for key in keys:
@@ -481,9 +471,9 @@ def _transpose_hf_weight(key, weight):
             ):
                 scale = f.get_tensor(key)
                 if not return_numpy and device == "expected":
-                    with device_guard():
-                        scale = paddle.Tensor.__call__(scale, zero_copy=True)
                     scale = scale._copy_to(paddle.framework._current_expected_place(), False)
+                if return_numpy:
+                    scale = scale.numpy()
                 scale_dict[key] = scale
     return part_state_dict, scale_dict
```
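The deletions above are the point of the change: with `framework="paddle"`, a slice materializes as a `paddle.Tensor`, so both the `paddle.Tensor.__call__(..., zero_copy=True)` wrapping and the scalar-only `.get()` path disappear. A hedged illustration (tensor name hypothetical):

```python
from safetensors import safe_open

with safe_open("model.safetensors", framework="paddle") as f:
    sl = f.get_slice("linear.weight")
    weight = sl[:]  # already a paddle.Tensor; previously a numpy array
```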

```diff
@@ -511,26 +501,34 @@ def load_state_dict(
     if (
         checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file)
     ) and is_safetensors_available():
-        # Check format of the archive
-        with safe_open(checkpoint_file, framework="np") as f:
-            metadata = {"format": "np"}
-
-        if metadata.get("format", "np") not in ["pd", "np"]:
-            raise OSError(
-                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
-                "you save your model with the `save_pretrained` method."
-            )
-        if metadata.get("format", "np") == "pd":
-            raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
-        if metadata.get("format", "np") == "np":
-            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
-            if thread_num > 1:
-                logger.info(f"Set loading state_dict thread num to {thread_num}")
-            state_dict, scale_dict = {}, {}
-            if thread_num <= 1:
-                with safe_open(checkpoint_file, framework="np") as f:
-                    state_dict, scale_dict = _load_part_state_dict(
-                        list(f.keys()),
+        thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
+        if thread_num > 1:
+            logger.info(f"Set loading state_dict thread num to {thread_num}")
+        state_dict, scale_dict = {}, {}
+        if thread_num <= 1:
+            with safe_open(checkpoint_file, framework="paddle") as f:
+                state_dict, scale_dict = _load_part_state_dict(
+                    list(f.keys()),
+                    checkpoint_file,
+                    tensor_parallel_split_mapping,
+                    fliter_dict_keys,
+                    device,
+                    quantization_linear_list,
+                    quantization_config,
+                    dtype,
+                    return_numpy,
+                    convert_from_hf,
+                    transpose_weight_keys,
+                )
+        else:
+            # Load state dict in multi-thread to speed up loading
+            with safe_open(checkpoint_file, framework="paddle") as f:
+                keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+                with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                    future_to_key = {
+                        executor.submit(
+                            _load_part_state_dict,
+                            keys,
                             checkpoint_file,
                             tensor_parallel_split_mapping,
                             fliter_dict_keys,
@@ -541,54 +539,41 @@ def load_state_dict(
                             return_numpy,
                             convert_from_hf,
                             transpose_weight_keys,
+                        ): keys
+                        for keys in keys_groups
+                    }
+                    for future in concurrent.futures.as_completed(future_to_key):
+                        res_state_dict, res_scale_dict = future.result()
+                        state_dict.update(res_state_dict)
+                        scale_dict.update(res_scale_dict)
+
+        if not return_numpy:
+            if device == "pin_memory":
+                for k in list(state_dict.keys()):
+                    pd_tensor = state_dict.pop(k)
+                    state_dict[k] = (
+                        pd_tensor
+                        if pd_tensor.place == paddle.CUDAPinnedPlace()
+                        else pd_tensor.to(paddle.CUDAPinnedPlace())
                     )
-            else:
-                # Load state dict in multi-thread to speed up loading
-                with safe_open(checkpoint_file, framework="np") as f:
-                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
-                    with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
-                        future_to_key = {
-                            executor.submit(
-                                _load_part_state_dict,
-                                keys,
-                                checkpoint_file,
-                                tensor_parallel_split_mapping,
-                                fliter_dict_keys,
-                                device,
-                                quantization_linear_list,
-                                quantization_config,
-                                dtype,
-                                return_numpy,
-                                convert_from_hf,
-                                transpose_weight_keys,
-                            ): keys
-                            for keys in keys_groups
-                        }
-                        for future in concurrent.futures.as_completed(future_to_key):
-                            res_state_dict, res_scale_dict = future.result()
-                            state_dict.update(res_state_dict)
-                            scale_dict.update(res_scale_dict)
-
-            if not return_numpy:
-                if device == "cpu":
-                    with device_guard():
-                        for k in list(state_dict.keys()):
-                            state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
-                elif device == "pin_memory":
-                    for k in list(state_dict.keys()):
-                        state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace())
+        else:
+            for k in list(state_dict.keys()):
+                state_dict[k] = state_dict.pop(k).numpy()
 
-            if len(scale_dict) != 0:
-                if ckpt_quant_stage == "O0":
-                    raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
-                state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)
+        if len(scale_dict) != 0:
+            if ckpt_quant_stage == "O0":
+                raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
+            state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)
 
-            return state_dict
+        return state_dict
 
     # load from hf but not safetensors checkpoint
     if convert_from_hf:
         state_dict = load_torch(checkpoint_file)
         state_dict = ConversionMixin.convert_transpose_selected_weights(state_dict, transpose_weight_keys)
+        if return_numpy:
+            for k in list(state_dict.keys()):
+                state_dict[k] = state_dict.pop(k).numpy()
        return state_dict
 
     state_dict = paddleformers_load(checkpoint_file, map_location="cpu")
```
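A hedged sketch of how the rewritten branch behaves from a caller's perspective (parameter names are taken from the diff; the call is illustrative, not the full signature):

```python
import os

from paddleformers.transformers.model_utils import load_state_dict

# Optional multi-threaded loading, same env knob as before.
os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "4"

# Tensors come back as paddle.Tensor by default; device="pin_memory" now
# moves them with t.to(paddle.CUDAPinnedPlace()) instead of round-tripping
# through paddle.to_tensor on numpy arrays.
state_dict = load_state_dict("model.safetensors", device="pin_memory", return_numpy=False)

# With return_numpy=True the paddle tensors are converted once at the end
# via .numpy(), rather than being loaded as numpy arrays up front.
```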
```diff
@@ -599,10 +584,8 @@ def prepare_safe_save_state_dict(state_dict, save_to_hf=False):
     for k in list(state_dict.keys()):
         if isinstance(state_dict[k], paddle.Tensor):
             if state_dict[k].dtype == paddle.bfloat16:
-                state_dict[k] = state_dict.pop(k).astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
-            else:
-                state_dict[k] = state_dict.pop(k).cpu().numpy()
-    metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
+                state_dict[k] = state_dict.pop(k).contiguous().astype(paddle.bfloat16)
+    metadata = {"format": "pt"} if save_to_hf else {"format": "paddle"}
     return state_dict, metadata
```
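A sketch of the simplified bfloat16 handling, assuming a bf16 tensor in the state dict: the bf16 → float32 → numpy → `ml_dtypes.bfloat16` round-trip is gone, and the tensor only needs to be made contiguous for the paddle writer:

```python
import paddle

t = paddle.randn([2, 2]).astype(paddle.bfloat16)
# Equivalent of the new branch above: stays a paddle.Tensor in bf16.
prepared = t.contiguous().astype(paddle.bfloat16)
```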

```diff
@@ -2051,7 +2034,6 @@ def get_file_path(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, v
             f"Error no files {filenames} found in repo {pretrained_model_name_or_path}."
         )
     elif "pytorch_model.bin" in str(resolved_archive_file):
-
         if download_hub == DownloadSource.AISTUDIO and not convert_from_hf:
             raise ValueError(
                 f"Download pytorch weight in "
@@ -2632,9 +2614,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             logger.warning("`load_state_as_np` is deprecated, please delete it!")
 
         model_kwargs = kwargs
-
         if convert_from_hf is None and download_hub == DownloadSource.MODELSCOPE:
-
             logger.warning(
                 "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
                 " you can set ·convert_from_hf=False·. By default, `convert_from_hf` is set to `True`. "
@@ -2707,7 +2687,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
             state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
         elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model.safetensors"):
-            with safe_open(resolved_archive_file, framework="np", device="cpu") as f:
+            with safe_open(resolved_archive_file, framework="paddle", device="cpu") as f:
                 loaded_keys = f.keys()
             tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
             state_dict = load_state_dict(
@@ -3352,7 +3332,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False, convert_from_hf=
     elif os.path.exists(model_path):
         state_dict = cls.convert_tensor_parallel(model_path, config)
     elif os.path.exists(safe_model_path):
-        with safe_open(safe_model_path, framework="np", device="cpu") as f:
+        with safe_open(safe_model_path, framework="paddle", device="cpu") as f:
             loaded_keys = f.keys()
         tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
         state_dict = load_state_dict(
```

requirements.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -8,7 +8,7 @@ sentencepiece
 huggingface_hub>=0.19.2
 protobuf>=3.20.2
 visualdl
-safetensors
+safetensors @ https://paddle-whl.bj.bcebos.com/nightly/cu126/safetensors/safetensors-0.6.2.dev0-cp38-abi3-linux_x86_64.whl
 fast_dataindex>=0.1.1 ; platform_system == "Linux"
 aistudio-sdk>=0.3.0
 jinja2
```
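The pin is worth a note: judging by the URL, this replaces the stock PyPI `safetensors` with a nightly (0.6.2.dev0), CUDA 12.6, Linux x86-64-only wheel from the Paddle wheel index, presumably because the `framework="paddle"` support in `safe_open` that the code above relies on is not yet available in a stable release.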
