@@ -21,6 +21,7 @@
 import json
 import os
 import re
+import sys
 import tempfile
 import warnings
 from contextlib import contextmanager
@@ -29,6 +30,7 @@
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import aistudio_sdk
+import ml_dtypes
 import numpy as np
 import paddle
 import paddle.nn as nn
@@ -126,9 +128,14 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):
 
 
 if is_safetensors_available():
-    from safetensors import safe_open
-    from safetensors.paddle import load_file as safe_load_file
-    from safetensors.paddle import save_file as safe_save_file
+    from safetensors.numpy import save_file as safe_save_file
+
+    from ..utils.safetensors import fast_load_file as safe_load_file
+
+    if sys.platform.startswith("win"):
+        from safetensors import safe_open
+    else:
+        from ..utils.safetensors import fast_safe_open as safe_open
 
 
 def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
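Side note on the import hunk above: the save/load entry points are rebound at import time, so downstream code keeps calling `safe_save_file`, `safe_load_file`, and `safe_open` unchanged. A minimal sketch of the `safetensors.numpy` API the save path now targets (file name illustrative; only upstream safetensors and numpy assumed):

    import numpy as np
    from safetensors.numpy import load_file, save_file

    # Round-trip a dict of plain numpy arrays through a safetensors archive.
    state = {"linear.weight": np.zeros((4, 4), dtype="float32")}
    save_file(state, "model.safetensors", metadata={"format": "np"})
    restored = load_file("model.safetensors")  # dict of numpy arrays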
@@ -395,7 +402,7 @@ def _transpose_hf_weight(key, weight):
 
     part_state_dict = {}
     scale_dict = {}
-    with safe_open(checkpoint_file, framework="paddle") as f:
+    with safe_open(checkpoint_file, framework="np") as f:
         for key in keys:
             # 1. non-merge ckpt loading dont have filter key.
             # 2. merge ckpt will skip quant scale by `fliter_dict_keys`
@@ -415,7 +422,8 @@ def _transpose_hf_weight(key, weight):
                 and key.split(".weight")[0] in quantization_linear_list
                 and not key.endswith("_scale")
             ):
-                weight = py_safe_slice_[:]
+                # numpy.array -> paddle.tensor
+                weight = paddle.Tensor.__call__(py_safe_slice_[:], zero_copy=True)
                 weight = _transpose_hf_weight(key, weight)
                 key_name = key.split(".weight")[0]
                 quant_key_name = key_name + ".quant_weight"
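`paddle.Tensor.__call__(...)` above is simply the eager `paddle.Tensor` constructor written out; with `zero_copy=True` the resulting CPU tensor wraps the existing numpy buffer rather than copying it. A standalone sketch of the idiom (my illustration, assuming eager mode; not part of the patch):

    import numpy as np
    import paddle

    arr = np.arange(6, dtype="float32").reshape(2, 3)
    # zero_copy=True shares the numpy buffer with the new CPU tensor.
    t = paddle.Tensor(arr, zero_copy=True)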
@@ -450,17 +458,19 @@ def _transpose_hf_weight(key, weight):
                     is_column = not is_column
                     tp_fn = partial(tp_fn.func, *tp_fn.args, **{**tp_fn.keywords, "is_column": is_column})
                 if len(py_safe_slice_.shape) == 0:
-                    weight = tp_fn(py_safe_slice_[:])
+                    weight = tp_fn(py_safe_slice_.get())
                 else:
                     weight = tp_fn(py_safe_slice_)
             else:
-                weight = py_safe_slice_[:]
-
+                if len(py_safe_slice_.shape) == 0:
+                    weight = py_safe_slice_.get()
+                else:
+                    weight = py_safe_slice_[:]
             if not return_numpy and device == "expected":
+                with device_guard():
+                    weight = paddle.Tensor.__call__(weight, zero_copy=True)
                 weight = weight._copy_to(paddle.framework._current_expected_place(), False)
             weight = _transpose_hf_weight(key, weight)
-            if return_numpy:
-                weight = weight.numpy()
             part_state_dict[key] = weight
 
         for key in keys:
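The `tp_fn = partial(tp_fn.func, *tp_fn.args, **{**tp_fn.keywords, "is_column": is_column})` line in this hunk rebuilds a `functools.partial` with a single keyword flipped. The idiom in isolation, with a toy split function (names hypothetical):

    from functools import partial

    def split_weight(weight, is_column=True):
        # Toy stand-in for a tensor-parallel split function.
        return ("column" if is_column else "row", weight)

    tp_fn = partial(split_weight, is_column=True)
    # Rebuild the partial, overriding one keyword and keeping the rest.
    tp_fn = partial(tp_fn.func, *tp_fn.args, **{**tp_fn.keywords, "is_column": False})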
@@ -471,9 +481,9 @@ def _transpose_hf_weight(key, weight):
             ):
                 scale = f.get_tensor(key)
                 if not return_numpy and device == "expected":
+                    with device_guard():
+                        scale = paddle.Tensor.__call__(scale, zero_copy=True)
                     scale = scale._copy_to(paddle.framework._current_expected_place(), False)
-                if return_numpy:
-                    scale = scale.numpy()
                 scale_dict[key] = scale
 
     return part_state_dict, scale_dict
@@ -501,34 +511,26 @@ def load_state_dict(
     if (
         checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file)
     ) and is_safetensors_available():
-        thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
-        if thread_num > 1:
-            logger.info(f"Set loading state_dict thread num to {thread_num}")
-        state_dict, scale_dict = {}, {}
-        if thread_num <= 1:
-            with safe_open(checkpoint_file, framework="paddle") as f:
-                state_dict, scale_dict = _load_part_state_dict(
-                    list(f.keys()),
-                    checkpoint_file,
-                    tensor_parallel_split_mapping,
-                    fliter_dict_keys,
-                    device,
-                    quantization_linear_list,
-                    quantization_config,
-                    dtype,
-                    return_numpy,
-                    convert_from_hf,
-                    transpose_weight_keys,
-                )
-        else:
-            # Load state dict in multi-thread to speed up loading
-            with safe_open(checkpoint_file, framework="paddle") as f:
-                keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
-            with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
-                future_to_key = {
-                    executor.submit(
-                        _load_part_state_dict,
-                        keys,
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="np") as f:
+            metadata = {"format": "np"}
+
+        if metadata.get("format", "np") not in ["pd", "np"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_pretrained` method."
+            )
+        if metadata.get("format", "np") == "pd":
+            raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
+        if metadata.get("format", "np") == "np":
+            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
+            if thread_num > 1:
+                logger.info(f"Set loading state_dict thread num to {thread_num}")
+            state_dict, scale_dict = {}, {}
+            if thread_num <= 1:
+                with safe_open(checkpoint_file, framework="np") as f:
+                    state_dict, scale_dict = _load_part_state_dict(
+                        list(f.keys()),
                         checkpoint_file,
                         tensor_parallel_split_mapping,
                         fliter_dict_keys,
@@ -539,41 +541,54 @@ def load_state_dict(
                         return_numpy,
                         convert_from_hf,
                         transpose_weight_keys,
-                    ): keys
-                    for keys in keys_groups
-                }
-                for future in concurrent.futures.as_completed(future_to_key):
-                    res_state_dict, res_scale_dict = future.result()
-                    state_dict.update(res_state_dict)
-                    scale_dict.update(res_scale_dict)
-
-        if not return_numpy:
-            if device == "pin_memory":
-                for k in list(state_dict.keys()):
-                    pd_tensor = state_dict.pop(k)
-                    state_dict[k] = (
-                        pd_tensor
-                        if pd_tensor.place == paddle.CUDAPinnedPlace()
-                        else pd_tensor.to(paddle.CUDAPinnedPlace())
                     )
-            else:
-                for k in list(state_dict.keys()):
-                    state_dict[k] = state_dict.pop(k).numpy()
+            else:
+                # Load state dict in multi-thread to speed up loading
+                with safe_open(checkpoint_file, framework="np") as f:
+                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
+                with concurrent.futures.ThreadPoolExecutor(max_workers=thread_num) as executor:
+                    future_to_key = {
+                        executor.submit(
+                            _load_part_state_dict,
+                            keys,
+                            checkpoint_file,
+                            tensor_parallel_split_mapping,
+                            fliter_dict_keys,
+                            device,
+                            quantization_linear_list,
+                            quantization_config,
+                            dtype,
+                            return_numpy,
+                            convert_from_hf,
+                            transpose_weight_keys,
+                        ): keys
+                        for keys in keys_groups
+                    }
+                    for future in concurrent.futures.as_completed(future_to_key):
+                        res_state_dict, res_scale_dict = future.result()
+                        state_dict.update(res_state_dict)
+                        scale_dict.update(res_scale_dict)
+
+            if not return_numpy:
+                if device == "cpu":
+                    with device_guard():
+                        for k in list(state_dict.keys()):
+                            state_dict[k] = paddle.Tensor.__call__(state_dict.pop(k), zero_copy=True)
+                elif device == "pin_memory":
+                    for k in list(state_dict.keys()):
+                        state_dict[k] = paddle.to_tensor(state_dict.pop(k), place=paddle.CUDAPinnedPlace())
 
-        if len(scale_dict) != 0:
-            if ckpt_quant_stage == "O0":
-                raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
-            state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)
+            if len(scale_dict) != 0:
+                if ckpt_quant_stage == "O0":
+                    raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
+                state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)
 
-        return state_dict
+            return state_dict
 
     # load from hf but not safetensors checkpoint
     if convert_from_hf:
         state_dict = load_torch(checkpoint_file)
         state_dict = ConversionMixin.convert_transpose_selected_weights(state_dict, transpose_weight_keys)
-        if return_numpy:
-            for k in list(state_dict.keys()):
-                state_dict[k] = state_dict.pop(k).numpy()
         return state_dict
 
     state_dict = paddleformers_load(checkpoint_file, map_location="cpu")
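For the hunk above: multi-threaded loading stays opt-in through `LOAD_STATE_DICT_THREAD_NUM`, with the shard's keys split evenly across worker threads. A hedged usage sketch (file name illustrative, and it assumes the remaining `load_state_dict` parameters have defaults):

    import os

    # Opt in to multi-threaded safetensors loading before the call.
    os.environ["LOAD_STATE_DICT_THREAD_NUM"] = "4"
    state_dict = load_state_dict("model-00001-of-00008.safetensors")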
@@ -584,8 +599,10 @@ def prepare_safe_save_state_dict(state_dict, save_to_hf=False):
     for k in list(state_dict.keys()):
         if isinstance(state_dict[k], paddle.Tensor):
             if state_dict[k].dtype == paddle.bfloat16:
-                state_dict[k] = state_dict.pop(k).contiguous().astype(paddle.bfloat16)
-    metadata = {"format": "pt"} if save_to_hf else {"format": "paddle"}
+                state_dict[k] = state_dict.pop(k).astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
+            else:
+                state_dict[k] = state_dict.pop(k).cpu().numpy()
+    metadata = {"format": "pt"} if save_to_hf else {"format": "np"}
     return state_dict, metadata
 
 
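Because numpy has no native bfloat16 dtype, the new save path upcasts to float32 and then reinterprets through `ml_dtypes`. The conversion in isolation (only paddle and ml_dtypes assumed):

    import ml_dtypes
    import paddle

    t = paddle.ones([2, 2], dtype="bfloat16")
    # float32 round-trip, then down to a numpy-storable bfloat16.
    arr = t.astype("float32").cpu().numpy().astype(ml_dtypes.bfloat16)
    assert arr.dtype == ml_dtypes.bfloat16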
@@ -2034,6 +2051,7 @@ def get_file_path(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, v
                     f"Error no files {filenames} found in repo {pretrained_model_name_or_path}."
                 )
         elif "pytorch_model.bin" in str(resolved_archive_file):
+
             if download_hub == DownloadSource.AISTUDIO and not convert_from_hf:
                 raise ValueError(
                     f"Download pytorch weight in "
@@ -2614,7 +2632,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             logger.warning("`load_state_as_np` is deprecated, please delete it!")
 
         model_kwargs = kwargs
+
         if convert_from_hf is None and download_hub == DownloadSource.MODELSCOPE:
+
             logger.warning(
                 "If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
                 " you can set ·convert_from_hf=False·. By default, `convert_from_hf` is set to `True`. "
@@ -2687,7 +2707,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         if config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model_state.pdparams"):
             state_dict = cls.convert_tensor_parallel(resolved_archive_file, config)
         elif config.tensor_parallel_degree > 1 and resolved_archive_file.endswith("model.safetensors"):
-            with safe_open(resolved_archive_file, framework="paddle", device="cpu") as f:
+            with safe_open(resolved_archive_file, framework="np", device="cpu") as f:
                 loaded_keys = f.keys()
             tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
             state_dict = load_state_dict(
@@ -3332,7 +3352,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False, convert_from_hf=
     elif os.path.exists(model_path):
         state_dict = cls.convert_tensor_parallel(model_path, config)
     elif os.path.exists(safe_model_path):
-        with safe_open(safe_model_path, framework="paddle", device="cpu") as f:
+        with safe_open(safe_model_path, framework="np", device="cpu") as f:
             loaded_keys = f.keys()
         tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
         state_dict = load_state_dict(
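Both tensor-parallel call sites now open shards with `framework="np"`, so `f.get_tensor(...)` returns numpy arrays instead of paddle tensors. A minimal read sketch against upstream safetensors (file name illustrative):

    from safetensors import safe_open

    with safe_open("model.safetensors", framework="np", device="cpu") as f:
        for name in f.keys():
            arr = f.get_tensor(name)  # numpy.ndarray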