Commit 479f729

Save model in original dtype for QAT example (#546)
Pick change in #531 to main.

## What does this PR do?

**Type of change:** ?

**Overview:** ?

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: Fridah-nv <[email protected]>
1 parent 8188a01 commit 479f729

File tree

1 file changed: +53 -15 lines changed


modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 53 additions & 15 deletions
```diff
@@ -16,6 +16,7 @@
 """ModelOpt plugin for transformers Trainer."""

 import gc
+import json
 import os
 import types
 from dataclasses import dataclass, field
```
```diff
@@ -168,6 +169,10 @@ def __init__(
         elif is_quantized(self.model):
             self._save_modelopt_state_with_weights()

+        self._original_dtype = getattr(
+            getattr(self.model, "config", None), "dtype", None
+        ) or getattr(getattr(self.model, "config", None), "torch_dtype", None)
+
     def _save_modelopt_state_with_weights(self):
         """Save the modelopt weights for fsdp2 models."""
         if torch.distributed.is_initialized():
```
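A minimal sketch (not part of the diff) of how the dtype captured here later becomes the string written to `config.json`; `DummyConfig` is a hypothetical stand-in for a Hugging Face `PretrainedConfig` that carries `torch_dtype` (or `dtype` in newer transformers versions):

```python
# Sketch only: the same getattr chain the trainer uses, on a stand-in config.
import torch


class DummyConfig:
    """Stand-in for the config of a model loaded in bfloat16."""

    torch_dtype = torch.bfloat16


config = DummyConfig()
# Prefer `dtype`, fall back to `torch_dtype`, as in the trainer __init__.
original_dtype = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
print(original_dtype)                     # torch.bfloat16
print(str(original_dtype).split(".")[1])  # "bfloat16" -> string later written to config.json
```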
```diff
@@ -256,23 +261,30 @@ def train(self, *args, **kwargs):

     def save_model(self, *args, **kwargs):
         """Save the quantized model."""
-        if (
-            (not self.is_in_train)
-            and self.is_fsdp_enabled
-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
-        ):
-            print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
-            original_type = self.accelerator.state.fsdp_plugin.state_dict_type
-            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-            outputs = super().save_model(*args, **kwargs)
-            if torch.distributed.is_initialized():
-                torch.distributed.barrier()
-            if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
+        if not self.is_in_train:
+            if (
+                self.is_fsdp_enabled
+                and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
+            ):
                 print_rank_0(
-                    "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
-                    "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
+                    "Setting state_dict_type to FULL_STATE_DICT for final checkpoint save."
                 )
-            self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
+                original_type = self.accelerator.state.fsdp_plugin.state_dict_type
+                self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
+            outputs = super().save_model(*args, **kwargs)
+            if torch.distributed.is_initialized():
+                torch.distributed.barrier()
+            if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
+                print_rank_0(
+                    "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
+                    "model. See https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.opt.plugins.huggingface.html#modelopt.torch.opt.plugins.huggingface.enable_huggingface_checkpointing"
+                )
+                self.accelerator.state.fsdp_plugin.set_state_dict_type(original_type)
+            if self.args.should_save:
+                out_dir = args[0]
+                # FSDP may upcast parameter dtype to float32 during mixed-precision training,
+                # we convert it back to original dtype by updating `torch-dtype` in `config.json`
+                self._update_config_json_dtype(out_dir, str(self._original_dtype).split(".")[1])
         else:
             outputs = super().save_model(*args, **kwargs)
         return outputs
```
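The restore flow referenced by the message above, sketched under the assumption of a causal-LM checkpoint; `qat_ckpt/` is a placeholder output directory, not a path from the PR:

```python
# Sketch: reloading a checkpoint written by save_model above.
import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# Register ModelOpt's HuggingFace checkpointing hooks before from_pretrained so the
# ModelOpt (quantizer) state saved alongside the weights is restored as well.
mto.enable_huggingface_checkpointing()
model = AutoModelForCausalLM.from_pretrained("qat_ckpt/")  # placeholder path
```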
```diff
@@ -296,6 +308,32 @@ def _load_best_model(self, *args, **kwargs):
         else:
             super()._load_best_model(*args, **kwargs)

+    def _update_config_json_dtype(self, output_dir: str, dtype_str: str | None) -> None:
+        """Rewrite <output_dir>/config.json 'dtype' (preferred) or 'torch_dtype' to dtype_str."""
+        cfg_path = os.path.join(output_dir, "config.json")
+        if not os.path.isfile(cfg_path):
+            print_rank_0(f"[warn] config.json not found under {output_dir}; skip dtype rewrite.")
+            return
+        try:
+            with open(cfg_path, encoding="utf-8") as f:
+                data = json.load(f)
+            # Prefer 'dtype', else fall back to 'torch_dtype'
+            key_to_update = (
+                "dtype" if "dtype" in data else ("torch_dtype" if "torch_dtype" in data else None)
+            )
+            if key_to_update is None:
+                print_rank_0(
+                    "[warn] Neither 'dtype' nor 'torch_dtype' present in config.json; skip dtype rewrite."
+                )
+                return
+            if data.get(key_to_update) != dtype_str:
+                data[key_to_update] = dtype_str
+                with open(cfg_path, "w", encoding="utf-8") as f:
+                    json.dump(data, f, ensure_ascii=False, indent=2)
+                print_rank_0(f'Updated config.json: {key_to_update} -> "{dtype_str}"')
+        except Exception as e:
+            print_rank_0(f"[warn] Failed to update dtype in config.json: {e}")
+
     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.

```
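A self-contained illustration (separate from the trainer code, with made-up config values) of the rewrite this helper performs on a checkpoint's `config.json`, assuming an FSDP mixed-precision run left `"torch_dtype": "float32"` behind:

```python
# Sketch: reproduce the config.json dtype rewrite on a throwaway checkpoint dir.
import json
import tempfile
from pathlib import Path

out_dir = Path(tempfile.mkdtemp())
cfg_path = out_dir / "config.json"
cfg_path.write_text(json.dumps({"model_type": "llama", "torch_dtype": "float32"}))

data = json.loads(cfg_path.read_text())
key = "dtype" if "dtype" in data else "torch_dtype"  # same key preference as the helper
data[key] = "bfloat16"                               # e.g. str(torch.bfloat16).split(".")[1]
cfg_path.write_text(json.dumps(data, indent=2))

print(cfg_path.read_text())  # torch_dtype now reads "bfloat16"
```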