 from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
+    async_save_state_dict_shards,
+    create_pinned_state_dict,
     get_model_base_filenames,
     get_optimizer_base_filenames,
     load_shard_state_dict,
 ...
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils.safetensors import load_flat
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -82,7 +85,15 @@ def save_unsharded_model(
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             if use_async:
-                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
+                from colossalai.utils.safetensors import save
+
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                for k, v in state_dict.items():
+                    self.pinned_state_dicts[id(model)][k].copy_(v)
+                    state_dict[k] = self.pinned_state_dicts[id(model)][k]
+                writer = save(checkpoint, state_dict)
+                self.async_writers.append(writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors)
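
The async path above keeps one page-locked (pinned) CPU copy of the state dict per model, keyed by `id(model)`, copies the live tensors into it, and hands the pinned copy to a background writer collected on `self.async_writers`, so the device-to-host copy and the disk write overlap with training instead of blocking it. Below is a minimal, generic sketch of that caching pattern; `PinnedSaver` and `save_async` are illustrative stand-ins, not ColossalAI's `create_pinned_state_dict`/`save` helpers.

```python
# Illustrative sketch of the pinned-buffer caching pattern used above.
# `PinnedSaver` and `save_async` are hypothetical, not the
# colossalai.utils.safetensors API.
from concurrent.futures import ThreadPoolExecutor

import torch


class PinnedSaver:
    def __init__(self):
        self.pinned_state_dicts = {}  # id(obj) -> {name: pinned CPU tensor}
        self.async_writers = []
        self._pool = ThreadPoolExecutor(max_workers=1)

    def save_async(self, obj, state_dict, path):
        # Allocate the pinned CPU buffers once per object and reuse them
        # for every subsequent checkpoint. Pinned memory needs an accelerator,
        # so fall back to pageable memory on CPU-only machines.
        if id(obj) not in self.pinned_state_dicts:
            self.pinned_state_dicts[id(obj)] = {
                k: torch.empty(
                    v.shape, dtype=v.dtype, device="cpu", pin_memory=torch.cuda.is_available()
                )
                for k, v in state_dict.items()
            }
        pinned = self.pinned_state_dicts[id(obj)]
        for k, v in state_dict.items():
            pinned[k].copy_(v)  # device-to-host copy into the cached CPU buffer
        # Flush the pinned copy to disk in the background.
        writer = self._pool.submit(torch.save, pinned, path)
        self.async_writers.append(writer)
        return writer

    def synchronize(self):
        # Block until every pending checkpoint write has finished.
        for writer in self.async_writers:
            writer.result()
        self.async_writers.clear()
```

The trade-off is memory: each checkpointed object permanently holds a full CPU copy of its state, which is what makes repeated non-blocking saves cheap.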
@@ -106,7 +117,19 @@ def save_unsharded_optimizer(
         assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from colossalai.utils.safetensors import _flatten_optim_state_dict, save
+
+                flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
+                if id(optimizer) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
+                for k, v in flatten_state_dict.items():
+                    self.pinned_state_dicts[id(optimizer)][k].copy_(v)
+                    flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k]
+                writer = save(checkpoint, flatten_state_dict, metadata)
+                self.async_writers.append(writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)

     def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
@@ -137,17 +160,29 @@ def save_sharded_model(

         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

-        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(model) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(model)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(model)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = model.state_dict_shard(
+            max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint_path)

         # Save shards of optimizer states.
         is_master = self.coordinator.is_master()
         if use_async:
-            super().save_sharded_model(
-                model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint_path,
+                index_file=index_file,
+                base_filename=weights_name,
+                is_master=is_master,
             )
-
+            self.async_writers.extend(writers)
         else:
             total_size = save_state_dict_shards(
                 sharded_state_dict=state_dict_shard,
@@ -158,17 +193,17 @@ def save_sharded_model(
                 use_safetensors=use_safetensors,
             )

-            # only save the index file on the master rank
-            if self.coordinator.is_master():
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model.unwrap(), checkpoint_path)
-                self.logger.info(
-                    f"The model is split into checkpoint shards. "
-                    f"You can find where each parameters has been saved in the "
-                    f"index located at {save_index_file}.",
-                    ranks=[0],
-                )
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            save_config_file(model.unwrap(), checkpoint_path)
+            self.logger.info(
+                f"The model is split into checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}.",
+                ranks=[0],
+            )

     def load_sharded_model(
         self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
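
Note that the index-file block at the end of `save_sharded_model` appears to be only de-indented: both branches now produce `total_size`, so the index can be written whether or not the save is asynchronous. The sharded path itself streams the state dict in chunks bounded by `max_shard_size` and records each chunk in the index. A generic sketch of that size-bounded sharding idea is below; it is an illustration only, since `GeminiDDP.state_dict_shard` additionally gathers parameters that are partitioned across ranks.

```python
# Minimal sketch of yielding state-dict shards bounded by a byte budget.
# Illustration only: GeminiDDP.state_dict_shard also gathers ZeRO-partitioned
# parameters before yielding them.
import torch


def state_dict_shards(state_dict, max_shard_size_bytes):
    shard, shard_size = {}, 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and shard_size + tensor_size > max_shard_size_bytes:
            yield shard, shard_size  # current shard is full, flush it
            shard, shard_size = {}, 0
        shard[name] = tensor
        shard_size += tensor_size
    if shard:
        yield shard, shard_size  # final partial shard


# Example: split a small model into shards of at most 300 KiB.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))
for i, (shard, size) in enumerate(state_dict_shards(model.state_dict(), 300 * 1024)):
    print(f"shard {i}: {sorted(shard)} ({size} bytes)")
```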
@@ -201,7 +236,7 @@ def save_sharded_optimizer(
         Path(checkpoint).mkdir(parents=True, exist_ok=True)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)
@@ -212,17 +247,36 @@ def save_sharded_optimizer(
             torch.save(param_groups, group_file_path)

         # States are broken into shards within max_shard_size.
-        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
+        if use_async and self.coordinator.is_master():
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict_shard = optimizer.state_shard(
+            prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts
+        )

         # Save shards of optimizer states.
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=states_name,
-            is_master=self.coordinator.is_master(),
-            use_safetensors=False,
-        )
+        if use_async:
+            total_size, writers = async_save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                state_preprocess=True,
+            )
+            self.async_writers.extend(writers)
+        else:
+            total_size = save_state_dict_shards(
+                sharded_state_dict=state_dict_shard,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=False,
+            )

         # Wrap up index file. Only save it on master rank.
         if self.coordinator.is_master():
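
On the async branch, `async_save_state_dict_shards` returns the accumulated byte count plus one writer per shard file, and those writers are appended to `self.async_writers` so the call returns before the data hits disk (the new `state_preprocess=True` flag presumably flattens each optimizer shard the same way the unsharded path does). A hypothetical sketch of that "write each shard in the background and record it in an index" loop is below; the file naming and index layout are invented for illustration and are not the ColossalAI implementation.

```python
# Hypothetical sketch of the control flow behind async shard saving: write each
# shard through a background worker, remember its writer, and build an index
# mapping every key to the file that holds it.
import json
import os
from concurrent.futures import ThreadPoolExecutor

import torch

_pool = ThreadPoolExecutor(max_workers=2)


def async_save_shards_sketch(shard_iter, checkpoint_dir, base_name="optim"):
    """shard_iter yields (shard_dict, shard_size_in_bytes) pairs."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    weight_map, total_size, writers = {}, 0, []
    for i, (shard, shard_size) in enumerate(shard_iter):
        filename = f"{base_name}-{i:05d}.bin"
        for key in shard:
            weight_map[key] = filename  # the index records which file holds each key
        total_size += shard_size
        # Hand the shard to a background writer instead of blocking the caller.
        writers.append(_pool.submit(torch.save, shard, os.path.join(checkpoint_dir, filename)))
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    with open(os.path.join(checkpoint_dir, f"{base_name}.index.json"), "w") as f:
        json.dump(index, f, indent=2)
    return total_size, writers
```

Whatever the real helper does internally, callers must eventually wait on the returned writers (here collected on `self.async_writers`) before the checkpoint can safely be read back.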
@@ -264,7 +318,10 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
         # Load optimizer states from shard files under checkpoint path.
         # For each file, only load the states managed by current process.
         for shard_file in checkpoint_files:
-            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                state_dict_shard = load_flat(shard_file)
+            else:
+                state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             optimizer.load_param_states(state_dict_shard)
             del state_dict_shard
             gc.collect()
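
The loader now dispatches on the file extension: shards produced by the async path end in `.safetensors` (that is what `get_optimizer_base_filenames(prefix, use_safetensors=use_async)` selects above) and are read with `load_flat`, while legacy `.bin` shards still go through `load_shard_state_dict`. A safetensors shard is just a flat `{key: tensor}` map on disk, which can be verified with the public safetensors API; the file name and key names below are made up for the demo and do not mirror ColossalAI's actual shard layout.

```python
# Self-contained probe of what a flat safetensors shard looks like on disk,
# using only the public safetensors API. File and key names are hypothetical.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

shard = {
    "state.0.exp_avg": torch.zeros(4, 4),
    "state.0.exp_avg_sq": torch.zeros(4, 4),
    "state.0.step": torch.tensor(10.0),
}
save_file(shard, "optim-shard-00001.safetensors")

with safe_open("optim-shard-00001.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        print(key, tuple(f.get_tensor(key).shape))  # every entry is just a named tensor
```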