 from dataclasses import dataclass, field

 import torch
+import torch.distributed.checkpoint as dist_cp
+from accelerate.utils import save_fsdp_model
 from tqdm import tqdm

 import modelopt.torch.opt as mto
@@ -114,6 +116,15 @@ def check_awq_smoothquant(quant_cfg):
     return is_awq_smoothquant


+def restore_modelopt_state_with_weights(model, modelopt_state_path):
+    """Restore the modelopt weights for fsdp2 models."""
+    _modelopt_state = torch.load(modelopt_state_path, weights_only=False)
+    modelopt_weights = _modelopt_state.pop("modelopt_state_weights", None)
+    restore_from_modelopt_state(model, _modelopt_state)
+    if modelopt_weights is not None:
+        set_quantizer_state_dict(model, modelopt_weights)
+
+
 class QATTrainer(ModelOptHFTrainer):
     """A drop-in replacement of HuggingFace's Trainer for quantization aware training with ModelOpt.

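A minimal sketch of using the new module-level helper outside the trainer, assuming the state file was written by QATTrainer._save_modelopt_state_with_weights below; the checkpoint path and model are placeholders, not part of this change:

    from transformers import AutoModelForCausalLM

    # Hypothetical paths: "output_dir" is the trainer's output directory and
    # "modelopt_state_train.pth" is the file name QATTrainer writes there.
    model = AutoModelForCausalLM.from_pretrained("output_dir")
    restore_modelopt_state_with_weights(model, "output_dir/modelopt_state_train.pth")
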
@@ -165,10 +176,12 @@ def __init__(
         self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
         if os.path.exists(self._modelopt_state_path):
             self._restore_modelopt_state_with_weights()
-            print_rank_0("Restored modelopt state with weights.")
+        elif is_quantized(self.model):
+            self._save_modelopt_state_with_weights()

     def _save_modelopt_state_with_weights(self):
         """Save the modelopt weights for fsdp2 models."""
+        print_rank_0(f"Saving modelopt state to {self._modelopt_state_path}")
         if torch.distributed.is_initialized():
             torch.distributed.barrier()

@@ -179,18 +192,13 @@ def _save_modelopt_state_with_weights(self):
             for state in modelopt_state["modelopt_state_dict"]
             if "kd_loss" not in state and "export_student" not in state
         ]
-        modelopt_full_state = {
-            "modelopt_state": modelopt_state,
-            "modelopt_state_weights": get_quantizer_state_dict(self.model),
-        }
-
+        modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
         if self.args.should_save:
-            torch.save(modelopt_full_state, self._modelopt_state_path)
+            torch.save(modelopt_state, self._modelopt_state_path)

     def _restore_modelopt_state_with_weights(self):
-        modelopt_full_state = torch.load(self._modelopt_state_path, weights_only=False)
-        restore_from_modelopt_state(self.model, modelopt_full_state["modelopt_state"])
-        set_quantizer_state_dict(self.model, modelopt_full_state["modelopt_state_weights"])
+        restore_modelopt_state_with_weights(self.model, self._modelopt_state_path)
+        print_rank_0("Restored modelopt state with weights.")

     def _quantize_model(self):
         """Quantize the model. Restore the quantization state if it exists."""
@@ -219,7 +227,6 @@ def forward_loop(model):
         # Force garbage collection to free up memory
         gc.collect()

-        print_rank_0(f"Saving modelopt state to {self._modelopt_state_path}")
         self._save_modelopt_state_with_weights()
         torch.cuda.empty_cache()

@@ -247,17 +254,29 @@ def evaluate(self, *args, **kwargs):
             self.model, _ = self.accelerator.prepare(self.model, dummy_optimizer)
         return super().evaluate(*args, **kwargs)

-    def save_model(self, *args, **kwargs):
+    def save_model(
+        self, output_dir: str | None = None, _internal_call: bool = False, *args, **kwargs
+    ):
         """Save the quantized model."""
-        if (
-            (not self.is_in_train)
-            and self.is_fsdp_enabled
-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
-        ):
-            print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
-            # TODO: test is this fix works for multi-node training
-            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-        return super().save_model(*args, **kwargs)
+        dict_type = (
+            str(self.accelerator.state.fsdp_plugin.state_dict_type) if self.is_fsdp_enabled else ""
+        )
+        if not _internal_call and self.is_fsdp_enabled and "SHARDED_STATE_DICT" in dict_type:
+            # The default Trainer.save_model doesn't save a checkpoint with SHARDED_STATE_DICT + FSDP.
+            # We save the model manually at the end of training so the last checkpoint can be
+            # converted from distcp to a HF-compatible format.
+            if output_dir is None:
+                output_dir = self.args.output_dir
+            save_fsdp_model(
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                self.model,
+                output_dir,
+            )
+            self.processing_class.save_pretrained(output_dir)
+            self.model.config.save_pretrained(output_dir)
+        else:
+            super().save_model(output_dir, _internal_call, *args, **kwargs)

     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.
@@ -360,3 +379,37 @@ def save_model(
         return KDTrainer.save_model(
             self, output_dir, _internal_call, export_student, *args, **kwargs
         )
+
+
+def convert_sharded_model_to_hf_format(
+    model, model_path, modelopt_state_name="modelopt_state.pth", output_path=None
+):
+    """Convert a sharded model to HF format.
+
+    Args:
+        model: The original HF model.
+        model_path: The path to the sharded model with the pytorch_model_fsdp_0 directory.
+        modelopt_state_name: The name of the modelopt state file. If not provided, the default name
+            "modelopt_state.pth" will be used.
+        output_path: The path to save the converted model. If not provided, the model will be saved
+            to the same directory as the sharded model.
+    """
+    if output_path is None:
+        output_path = model_path
+    os.makedirs(output_path, exist_ok=True)
+    state_dict = {"model": model.state_dict()}
+    sharded_model_path = os.path.join(model_path, "pytorch_model_fsdp_0")
+    modelopt_state_path = os.path.join(model_path, modelopt_state_name)
+    if not os.path.exists(sharded_model_path):
+        print_rank_0(f"Sharded model path does not exist: {sharded_model_path}")
+        return model
+    dist_cp.load_state_dict(
+        state_dict=state_dict,
+        storage_reader=dist_cp.FileSystemReader(sharded_model_path),
+        no_dist=True,
+    )
+    model.load_state_dict(state_dict["model"])
+    restore_modelopt_state_with_weights(model, modelopt_state_path)
+    mto.enable_huggingface_checkpointing()
+    model.save_pretrained(output_path)
+    return model
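A possible offline conversion of that sharded checkpoint back to a regular HF checkpoint using the helper above; the base model id and paths are placeholders, and the state file name matches what QATTrainer writes into its output directory:

    from transformers import AutoModelForCausalLM

    # Placeholder model and paths: "qat_output" is the trainer output_dir from the sketch above.
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    convert_sharded_model_to_hf_format(
        base_model,
        model_path="qat_output",
        modelopt_state_name="modelopt_state_train.pth",  # file name used by QATTrainer
        output_path="qat_output_hf",
    )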