@@ -206,10 +206,10 @@ def forward_loop(model):
206
206
for batch in tqdm (data_loader , desc = "Calibrating" ):
207
207
batch = self ._prepare_inputs (batch )
208
208
# Important: We should forward pass using the unwrapped model
209
- # mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
209
+ # mtq.quantize will unwrap the model and pass it to the forward_loop
210
210
self .model (** batch )
211
211
212
- # TODO: Remove calibrate_with_adpaters - this should not be needed
212
+ # TODO: Remove calibrate_with_adapters - this should not be needed
213
213
with calibrate_with_adapters (self .model , self .args ):
214
214
print_rank_0 ("Quantizing the model..." )
215
215
mtq .quantize (self .model , self .quant_cfg , forward_loop ) # type: ignore [arg-type]
@@ -252,7 +252,8 @@ def train(self, *args, **kwargs):
252
252
"""Train the model."""
253
253
outputs = super ().train (* args , ** kwargs )
254
254
print_rank_0 (
255
- "Training completed. Do not forget to save the final model using `trainer.save_model()`."
255
+ "Training completed. Please save the final model using `Trainer.save_model()` "
256
+ "to preserve ModelOpt states."
256
257
)
257
258
return outputs
258
259
@@ -264,10 +265,17 @@ def save_model(self, *args, **kwargs):
264
265
and self .accelerator .state .fsdp_plugin .state_dict_type != "FULL_STATE_DICT"
265
266
):
266
267
print_rank_0 ("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save." )
268
+ original_type = self .accelerator .state .fsdp_plugin .state_dict_type
267
269
self .accelerator .state .fsdp_plugin .set_state_dict_type ("FULL_STATE_DICT" )
268
270
outputs = super ().save_model (* args , ** kwargs )
269
- torch .distributed .barrier ()
270
- print_rank_0 ("Saved serialized model" )
271
+ if torch .distributed .is_initialized ():
272
+ torch .distributed .barrier ()
273
+ if mto .ModeloptStateManager .is_converted (self .accelerator .unwrap_model (self .model )):
274
+ print_rank_0 (
275
+ "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
276
+ "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
277
+ )
278
+ self .accelerator .state .fsdp_plugin .set_state_dict_type (original_type )
271
279
else :
272
280
outputs = super ().save_model (* args , ** kwargs )
273
281
return outputs
0 commit comments