21
21
import json
22
22
import os
23
23
import re
24
- import sys
25
24
import tempfile
26
25
import warnings
27
26
from contextlib import contextmanager
30
29
from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
31
30
32
31
import aistudio_sdk
33
- import ml_dtypes
34
32
import numpy as np
35
33
import paddle
36
34
import paddle .nn as nn
@@ -128,14 +126,9 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):
128
126
129
127
130
128
if is_safetensors_available ():
131
- from safetensors .numpy import save_file as safe_save_file
132
-
133
- from ..utils .safetensors import fast_load_file as safe_load_file
134
-
135
- if sys .platform .startswith ("win" ):
136
- from safetensors import safe_open
137
- else :
138
- from ..utils .safetensors import fast_safe_open as safe_open
129
+ from safetensors import safe_open
130
+ from safetensors .paddle import load_file as safe_load_file
131
+ from safetensors .paddle import save_file as safe_save_file
139
132
140
133
141
134
def prune_linear_layer (layer : nn .Linear , index : paddle .Tensor , dim : int = 0 ) -> nn .Linear :
@@ -402,7 +395,7 @@ def _transpose_hf_weight(key, weight):
402
395
403
396
part_state_dict = {}
404
397
scale_dict = {}
405
- with safe_open (checkpoint_file , framework = "np " ) as f :
398
+ with safe_open (checkpoint_file , framework = "paddle " ) as f :
406
399
for key in keys :
407
400
# 1. Non-merge checkpoint loading does not have a filter key.
408
401
# 2. Merge checkpoint loading skips quantization scales via `fliter_dict_keys`.
@@ -422,8 +415,7 @@ def _transpose_hf_weight(key, weight):
422
415
and key .split (".weight" )[0 ] in quantization_linear_list
423
416
and not key .endswith ("_scale" )
424
417
):
425
- # numpy.array -> paddle.tensor
426
- weight = paddle .Tensor .__call__ (py_safe_slice_ [:], zero_copy = True )
418
+ weight = py_safe_slice_ [:]
427
419
weight = _transpose_hf_weight (key , weight )
428
420
key_name = key .split (".weight" )[0 ]
429
421
quant_key_name = key_name + ".quant_weight"
@@ -458,19 +450,17 @@ def _transpose_hf_weight(key, weight):
458
450
is_column = not is_column
459
451
tp_fn = partial (tp_fn .func , * tp_fn .args , ** {** tp_fn .keywords , "is_column" : is_column })
460
452
if len (py_safe_slice_ .shape ) == 0 :
461
- weight = tp_fn (py_safe_slice_ . get () )
453
+ weight = tp_fn (py_safe_slice_ [:] )
462
454
else :
463
455
weight = tp_fn (py_safe_slice_ )
464
456
else :
465
- if len (py_safe_slice_ .shape ) == 0 :
466
- weight = py_safe_slice_ .get ()
467
- else :
468
- weight = py_safe_slice_ [:]
457
+ weight = py_safe_slice_ [:]
458
+
469
459
if not return_numpy and device == "expected" :
470
- with device_guard ():
471
- weight = paddle .Tensor .__call__ (weight , zero_copy = True )
472
460
weight = weight ._copy_to (paddle .framework ._current_expected_place (), False )
473
461
weight = _transpose_hf_weight (key , weight )
462
+ if return_numpy :
463
+ weight = weight .numpy ()
474
464
part_state_dict [key ] = weight
475
465
476
466
for key in keys :
@@ -481,9 +471,9 @@ def _transpose_hf_weight(key, weight):
481
471
):
482
472
scale = f .get_tensor (key )
483
473
if not return_numpy and device == "expected" :
484
- with device_guard ():
485
- scale = paddle .Tensor .__call__ (scale , zero_copy = True )
486
474
scale = scale ._copy_to (paddle .framework ._current_expected_place (), False )
475
+ if return_numpy :
476
+ scale = scale .numpy ()
487
477
scale_dict [key ] = scale
488
478
return part_state_dict , scale_dict
489
479
@@ -511,26 +501,34 @@ def load_state_dict(
511
501
if (
512
502
checkpoint_file .endswith (".safetensors" ) or re .search (r"\.safetensors_shard_\d{4}$" , checkpoint_file )
513
503
) and is_safetensors_available ():
514
- # Check format of the archive
515
- with safe_open (checkpoint_file , framework = "np" ) as f :
516
- metadata = {"format" : "np" }
517
-
518
- if metadata .get ("format" , "np" ) not in ["pd" , "np" ]:
519
- raise OSError (
520
- f"The safetensors archive passed at { checkpoint_file } does not contain the valid metadata. Make sure "
521
- "you save your model with the `save_pretrained` method."
522
- )
523
- if metadata .get ("format" , "np" ) == "pd" :
524
- raise ValueError ("Currently unsupport paddle weights file, use numpy instead." )
525
- if metadata .get ("format" , "np" ) == "np" :
526
- thread_num = int (os .environ .get ("LOAD_STATE_DICT_THREAD_NUM" , "1" ))
527
- if thread_num > 1 :
528
- logger .info (f"Set loading state_dict thread num to { thread_num } " )
529
- state_dict , scale_dict = {}, {}
530
- if thread_num <= 1 :
531
- with safe_open (checkpoint_file , framework = "np" ) as f :
532
- state_dict , scale_dict = _load_part_state_dict (
533
- list (f .keys ()),
504
+ thread_num = int (os .environ .get ("LOAD_STATE_DICT_THREAD_NUM" , "1" ))
505
+ if thread_num > 1 :
506
+ logger .info (f"Set loading state_dict thread num to { thread_num } " )
507
+ state_dict , scale_dict = {}, {}
508
+ if thread_num <= 1 :
509
+ with safe_open (checkpoint_file , framework = "paddle" ) as f :
510
+ state_dict , scale_dict = _load_part_state_dict (
511
+ list (f .keys ()),
512
+ checkpoint_file ,
513
+ tensor_parallel_split_mapping ,
514
+ fliter_dict_keys ,
515
+ device ,
516
+ quantization_linear_list ,
517
+ quantization_config ,
518
+ dtype ,
519
+ return_numpy ,
520
+ convert_from_hf ,
521
+ transpose_weight_keys ,
522
+ )
523
+ else :
524
# Load the state dict with multiple threads to speed up loading.
525
+ with safe_open (checkpoint_file , framework = "paddle" ) as f :
526
+ keys_groups = _split_keys_evenly (list (f .keys ()), thread_num )
527
+ with concurrent .futures .ThreadPoolExecutor (max_workers = thread_num ) as executor :
528
+ future_to_key = {
529
+ executor .submit (
530
+ _load_part_state_dict ,
531
+ keys ,
534
532
checkpoint_file ,
535
533
tensor_parallel_split_mapping ,
536
534
fliter_dict_keys ,
@@ -541,54 +539,41 @@ def load_state_dict(
541
539
return_numpy ,
542
540
convert_from_hf ,
543
541
transpose_weight_keys ,
542
+ ): keys
543
+ for keys in keys_groups
544
+ }
545
+ for future in concurrent .futures .as_completed (future_to_key ):
546
+ res_state_dict , res_scale_dict = future .result ()
547
+ state_dict .update (res_state_dict )
548
+ scale_dict .update (res_scale_dict )
549
+
550
+ if not return_numpy :
551
+ if device == "pin_memory" :
552
+ for k in list (state_dict .keys ()):
553
+ pd_tensor = state_dict .pop (k )
554
+ state_dict [k ] = (
555
+ pd_tensor
556
+ if pd_tensor .place == paddle .CUDAPinnedPlace ()
557
+ else pd_tensor .to (paddle .CUDAPinnedPlace ())
544
558
)
545
- else :
546
- # Load state dict in multi-thread to speed up loading
547
- with safe_open (checkpoint_file , framework = "np" ) as f :
548
- keys_groups = _split_keys_evenly (list (f .keys ()), thread_num )
549
- with concurrent .futures .ThreadPoolExecutor (max_workers = thread_num ) as executor :
550
- future_to_key = {
551
- executor .submit (
552
- _load_part_state_dict ,
553
- keys ,
554
- checkpoint_file ,
555
- tensor_parallel_split_mapping ,
556
- fliter_dict_keys ,
557
- device ,
558
- quantization_linear_list ,
559
- quantization_config ,
560
- dtype ,
561
- return_numpy ,
562
- convert_from_hf ,
563
- transpose_weight_keys ,
564
- ): keys
565
- for keys in keys_groups
566
- }
567
- for future in concurrent .futures .as_completed (future_to_key ):
568
- res_state_dict , res_scale_dict = future .result ()
569
- state_dict .update (res_state_dict )
570
- scale_dict .update (res_scale_dict )
571
-
572
- if not return_numpy :
573
- if device == "cpu" :
574
- with device_guard ():
575
- for k in list (state_dict .keys ()):
576
- state_dict [k ] = paddle .Tensor .__call__ (state_dict .pop (k ), zero_copy = True )
577
- elif device == "pin_memory" :
578
- for k in list (state_dict .keys ()):
579
- state_dict [k ] = paddle .to_tensor (state_dict .pop (k ), place = paddle .CUDAPinnedPlace ())
559
+ else :
560
+ for k in list (state_dict .keys ()):
561
+ state_dict [k ] = state_dict .pop (k ).numpy ()
580
562
581
- if len (scale_dict ) != 0 :
582
- if ckpt_quant_stage == "O0" :
583
- raise ValueError ('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"' )
584
- state_dict = dequant_unified_optimizer (state_dict , ckpt_quant_stage , scale_dict , use_pd = True )
563
+ if len (scale_dict ) != 0 :
564
+ if ckpt_quant_stage == "O0" :
565
+ raise ValueError ('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"' )
566
+ state_dict = dequant_unified_optimizer (state_dict , ckpt_quant_stage , scale_dict , use_pd = True )
585
567
586
- return state_dict
568
+ return state_dict
587
569
588
570
# Load from an HF checkpoint that is not in safetensors format.
589
571
if convert_from_hf :
590
572
state_dict = load_torch (checkpoint_file )
591
573
state_dict = ConversionMixin .convert_transpose_selected_weights (state_dict , transpose_weight_keys )
574
+ if return_numpy :
575
+ for k in list (state_dict .keys ()):
576
+ state_dict [k ] = state_dict .pop (k ).numpy ()
592
577
return state_dict
593
578
594
579
state_dict = paddleformers_load (checkpoint_file , map_location = "cpu" )
@@ -599,10 +584,8 @@ def prepare_safe_save_state_dict(state_dict, save_to_hf=False):
599
584
for k in list (state_dict .keys ()):
600
585
if isinstance (state_dict [k ], paddle .Tensor ):
601
586
if state_dict [k ].dtype == paddle .bfloat16 :
602
- state_dict [k ] = state_dict .pop (k ).astype ("float32" ).cpu ().numpy ().astype (ml_dtypes .bfloat16 )
603
- else :
604
- state_dict [k ] = state_dict .pop (k ).cpu ().numpy ()
605
- metadata = {"format" : "pt" } if save_to_hf else {"format" : "np" }
587
+ state_dict [k ] = state_dict .pop (k ).contiguous ().astype (paddle .bfloat16 )
588
+ metadata = {"format" : "pt" } if save_to_hf else {"format" : "paddle" }
606
589
return state_dict , metadata
607
590
608
591
@@ -2051,7 +2034,6 @@ def get_file_path(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME, v
2051
2034
f"Error no files { filenames } found in repo { pretrained_model_name_or_path } ."
2052
2035
)
2053
2036
elif "pytorch_model.bin" in str (resolved_archive_file ):
2054
-
2055
2037
if download_hub == DownloadSource .AISTUDIO and not convert_from_hf :
2056
2038
raise ValueError (
2057
2039
f"Download pytorch weight in "
@@ -2632,9 +2614,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
2632
2614
logger .warning ("`load_state_as_np` is deprecated, please delete it!" )
2633
2615
2634
2616
model_kwargs = kwargs
2635
-
2636
2617
if convert_from_hf is None and download_hub == DownloadSource .MODELSCOPE :
2637
-
2638
2618
logger .warning (
2639
2619
"If you are attempting to load weights from ModelScope Hub and want to disable the default behavior of considering torch weights,"
2640
2620
" you can set ·convert_from_hf=False·. By default, `convert_from_hf` is set to `True`. "
@@ -2707,7 +2687,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
2707
2687
if config .tensor_parallel_degree > 1 and resolved_archive_file .endswith ("model_state.pdparams" ):
2708
2688
state_dict = cls .convert_tensor_parallel (resolved_archive_file , config )
2709
2689
elif config .tensor_parallel_degree > 1 and resolved_archive_file .endswith ("model.safetensors" ):
2710
- with safe_open (resolved_archive_file , framework = "np " , device = "cpu" ) as f :
2690
+ with safe_open (resolved_archive_file , framework = "paddle " , device = "cpu" ) as f :
2711
2691
loaded_keys = f .keys ()
2712
2692
tp_actions = cls .get_tensor_parallel_convert_actions (config , loaded_keys )
2713
2693
state_dict = load_state_dict (
@@ -3352,7 +3332,7 @@ def load_tp_checkpoint(folder, cls, config, return_numpy=False, convert_from_hf=
3352
3332
elif os .path .exists (model_path ):
3353
3333
state_dict = cls .convert_tensor_parallel (model_path , config )
3354
3334
elif os .path .exists (safe_model_path ):
3355
- with safe_open (safe_model_path , framework = "np " , device = "cpu" ) as f :
3335
+ with safe_open (safe_model_path , framework = "paddle " , device = "cpu" ) as f :
3356
3336
loaded_keys = f .keys ()
3357
3337
tp_actions = cls .get_tensor_parallel_convert_actions (config , loaded_keys )
3358
3338
state_dict = load_state_dict (
0 commit comments