1
1
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
2
2
import json
3
- import warnings
4
3
from dataclasses import asdict , dataclass
5
4
from typing import Dict , List , Optional , Tuple
6
5
12
11
except ModuleNotFoundError :
13
12
raise ModuleNotFoundError ("Please install tensornvme to use NVMeOptimizer" )
14
13
_TYPES_INV = {v : k for k , v in _TYPES .items ()}
14
+ import io
15
+
16
+ from torch .distributed .distributed_c10d import _pickler , _unpickler
17
+
18
+
19
+ def _object_to_tensor (obj , device ):
20
+ f = io .BytesIO ()
21
+ _pickler (f ).dump (obj )
22
+ byte_storage = torch .ByteStorage ._from_buffer (f .getvalue ()) # type: ignore[attr-defined]
23
+ # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
24
+ # Otherwise, it will casue 100X slowdown.
25
+ # See: https://github.com/pytorch/pytorch/issues/65696
26
+ byte_tensor = torch .ByteTensor (byte_storage ).to (device )
27
+ return byte_tensor
28
+
29
+
30
+ def _tensor_to_object (tensor , tensor_size ):
31
+ tensor = tensor .cpu ()
32
+ buf = tensor .numpy ().tobytes ()[:tensor_size ]
33
+ return _unpickler (io .BytesIO (buf )).load ()
15
34
16
35
17
36
@dataclass
@@ -28,49 +47,68 @@ class PreparedData:
28
47
offset : int
29
48
30
49
31
- def flatten_dict (nested_dict , parent_key = "" , separator = "^" ):
32
- """
33
- Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
34
-
35
- nested_dict: The input nested dictionary.
36
- parent_key: The parent key currently being processed.
37
- separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
38
- """
39
- items = []
40
- for k , v in nested_dict .items ():
41
- new_key = f"{ parent_key } { separator } { k } " if parent_key else str (k )
42
- if isinstance (v , dict ):
43
- items .extend (flatten_dict (v , new_key , separator ).items ())
44
- else :
45
- v = torch .tensor (v , dtype = torch .float16 ) if not isinstance (v , torch .Tensor ) else v
46
- items .append ((new_key , v ))
47
-
48
- return dict (items )
49
-
50
-
51
- def unflatten_dict (flattened_dict , separator = "^" ):
52
- """
53
- Restore a flattened dictionary back to a multi-level nested dictionary.
54
-
55
- flattened_dict: The flattened dictionary.
56
- separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
57
- """
58
- nested_dict = {}
59
- for key , value in flattened_dict .items ():
60
- keys = key .split (separator )
61
- try :
62
- keys [0 ] = int (keys [0 ])
63
- except ValueError :
64
- warnings .warn (f"{ key [0 ]} can't convert to integer" )
65
- d = nested_dict
66
- for part in keys [:- 1 ]:
67
- if part not in d :
68
- d [part ] = {}
69
- d = d [part ]
70
- assert isinstance (value , torch .Tensor )
71
- d [keys [- 1 ]] = value
72
-
73
- return nested_dict
50
def _cast_to_tensor(obj):
    """Return ``obj`` unchanged if it is already a tensor; otherwise pickle it into a CPU byte tensor."""
    return obj if isinstance(obj, torch.Tensor) else _object_to_tensor(obj, "cpu")
54
+
55
+
56
def _cast_to_object(tensor: torch.Tensor):
    """Unpickle a byte tensor produced by ``_cast_to_tensor`` back into the original Python object."""
    n_bytes = tensor.numel() * tensor.element_size()
    return _tensor_to_object(tensor, n_bytes)
58
+
59
+
60
def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
    """Flatten an optimizer state dict into a flat ``{str: Tensor}`` mapping.

    Non-tensor values are pickled into byte tensors and their flat keys are
    recorded in the returned metadata so ``_unflatten_optim_state_dict`` can
    restore them.

    Args:
        state_dict: either a full optimizer state dict
            (``{"state": ..., "param_groups": ...}``) or a bare 2-level state
            shard (``{param_idx: {name: value}}``).
        seperator: joiner for the flat keys (name kept for caller compatibility).

    Returns:
        ``(flat_dict, metadata)``; ``metadata`` is ``None`` when every value
        was already a tensor.
    """
    flat_dict = {}
    non_tensor_keys = []
    if "state" in state_dict:
        # 3-level dict
        states = state_dict["state"]
    else:
        # 2-level dict, usually for optimizer state dict shard
        states = state_dict

    for idx, d in states.items():
        for k, v in d.items():
            nested_key = f"state{seperator}{idx}{seperator}{k}"
            if not isinstance(v, torch.Tensor):
                non_tensor_keys.append(nested_key)
            flat_dict[nested_key] = _cast_to_tensor(v)
    if "param_groups" in state_dict:
        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
        non_tensor_keys.append("param_groups")
    if len(non_tensor_keys) > 0:
        # safetensors header metadata must map str -> str, and
        # _unflatten_optim_state_dict json.loads() this value, so the key list
        # must be JSON-serialized here (a raw list would break the round trip).
        metadata = {"non_tensor_keys": json.dumps(non_tensor_keys)}
    else:
        metadata = None
    return flat_dict, metadata
84
+
85
+
86
+ def _unflatten_optim_state_dict (flat_dict : dict , metadata : Optional [dict ] = None , seperator : str = "." ):
87
+ state_dict = {}
88
+ if metadata is not None :
89
+ non_tensor_keys = json .loads (metadata ["non_tensor_keys" ])
90
+ else :
91
+ non_tensor_keys = []
92
+ flat_dict = {k : _cast_to_object (v ) if k in non_tensor_keys else v for k , v in flat_dict .items ()}
93
+ if "param_groups" in flat_dict :
94
+ # 3-level dict
95
+ state_dict ["param_groups" ] = flat_dict .pop ("param_groups" )
96
+ state_dict ["state" ] = {}
97
+ states = state_dict ["state" ]
98
+ else :
99
+ # 2-level dict, usually for optimizer state dict shard
100
+ states = state_dict
101
+
102
+ for k , v in flat_dict .items ():
103
+ parts = k .split (seperator )
104
+ assert len (parts ) == 3 and parts [0 ] == "state"
105
+ idx = int (parts [1 ])
106
+ key = parts [2 ]
107
+ if idx not in states :
108
+ states [idx ] = {}
109
+ states [idx ][key ] = v
110
+
111
+ return state_dict
74
112
75
113
76
114
def prepare (
@@ -124,10 +162,8 @@ def save(
124
162
f_writer .write_raw (tensor , tensor .data_ptr (), tensor .numel () * tensor .element_size (), f_writer .offset )
125
163
126
164
127
- def save_nested (
128
- f_writer : AsyncFileWriter , state_dict : Dict [str , torch .Tensor ], metadata : Optional [Dict [str , str ]] = None
129
- ) -> None :
130
- flatten_data = flatten_dict (state_dict )
165
def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
    """Flatten an optimizer state dict and write it (with its metadata) via ``f_writer``."""
    flat_state, meta = _flatten_optim_state_dict(state_dict)
    save(f_writer, flat_state, meta)
132
168
133
169
@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
154
190
with safe_open (checkpoint_path , framework = "pt" ) as f :
155
191
metadata = f .metadata ()
156
192
state_dict_load = load_file (checkpoint_path )
157
- state_dict = unflatten_dict (state_dict_load )
158
- if metadata is None :
159
- return state_dict
160
- metadata = dict (map (lambda item : (item [0 ], json .loads (item [1 ])), metadata .items ()))
161
- combined_state_dict = {"state" : state_dict }
162
- combined_state_dict .update (metadata )
163
- return combined_state_dict
193
+ state_dict = _unflatten_optim_state_dict (state_dict_load , metadata )
194
+ return state_dict
0 commit comments