Skip to content

Commit 339b7e6

Browse files
committed
minor
Signed-off-by: realAsma <[email protected]>
1 parent 032a4bf commit 339b7e6

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ def forward_loop(model):
206206
for batch in tqdm(data_loader, desc="Calibrating"):
207207
batch = self._prepare_inputs(batch)
208208
# Important: We should forward pass using the unwrapped model
209-
# mtq.quantize will unwrap the model pass the unwrapped model to the forward_loop
209+
# mtq.quantize will unwrap the model & pass to the forward_loop
210210
self.model(**batch)
211211

212-
# TODO: Remove calibrate_with_adpaters - this should not be needed
212+
# TODO: Remove calibrate_with_adapters - this should not be needed
213213
with calibrate_with_adapters(self.model, self.args):
214214
print_rank_0("Quantizing the model...")
215215
mtq.quantize(self.model, self.quant_cfg, forward_loop) # type: ignore [arg-type]
@@ -252,7 +252,8 @@ def train(self, *args, **kwargs):
252252
"""Train the model."""
253253
outputs = super().train(*args, **kwargs)
254254
print_rank_0(
255-
"Training completed. Do not forget to save the final model using `trainer.save_model()`."
255+
"Training completed. Please save the final model using `Trainer.save_model()` "
256+
"to preserve ModelOpt states."
256257
)
257258
return outputs
258259

@@ -264,10 +265,17 @@ def save_model(self, *args, **kwargs):
264265
and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
265266
):
266267
print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
268+
original_type = self.accelerator.state.fsdp_plugin.state_dict_type
267269
self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
268270
outputs = super().save_model(*args, **kwargs)
269-
torch.distributed.barrier()
270-
print_rank_0("Saved serialized model")
271+
if torch.distributed.is_initialized():
272+
torch.distributed.barrier()
273+
if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
274+
print_rank_0(
275+
"Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
276+
"model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
277+
)
278+
self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
271279
else:
272280
outputs = super().save_model(*args, **kwargs)
273281
return outputs

0 commit comments

Comments (0)