
Commit 53a4430

Restored FSDP2 sharded state_dict support

Signed-off-by: realAsma <[email protected]>

1 parent 77c48fe · commit 53a4430

File tree

4 files changed (+129 -23 lines)

examples/llm_qat/convert_sharded_ckpt.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+from transformers import AutoModelForCausalLM
+
+from modelopt.torch.quantization.plugins.transformers_trainer import (
+    convert_sharded_model_to_hf_format,
+)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert sharded checkpoint to HuggingFace format")
+    parser.add_argument(
+        "--hf_model_path", type=str, required=True, help="Path to the original HuggingFace model"
+    )
+    parser.add_argument(
+        "--sharded_ckpt_path",
+        type=str,
+        required=True,
+        help="Path to the sharded checkpoint directory",
+    )
+    parser.add_argument(
+        "--output_path", type=str, default="", help="Output path to save the converted model"
+    )
+
+    args = parser.parse_args()
+
+    model = AutoModelForCausalLM.from_pretrained(args.hf_model_path)
+    if os.path.exists(os.path.join(args.sharded_ckpt_path, "pytorch_model_fsdp_0")):
+        convert_sharded_model_to_hf_format(
+            model, args.sharded_ckpt_path, "modelopt_state_train.pth", args.output_path
+        )
+
+
+if __name__ == "__main__":
+    main()
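
Note: the script above is a thin CLI wrapper around convert_sharded_model_to_hf_format. The expected invocation mirrors the line this commit adds to launch.sh (all three paths below are placeholders, not defaults):

python convert_sharded_ckpt.py \
    --hf_model_path <original_hf_model> \
    --sharded_ckpt_path <qat_output_dir> \
    --output_path <converted_model_dir>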

examples/llm_qat/launch.sh

Lines changed: 2 additions & 1 deletion
@@ -180,4 +180,5 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
 
 start_time=$(date +%s)
 sh -c "$CMD"
-echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
+python convert_sharded_ckpt.py --hf_model_path $MODEL --sharded_ckpt_path $OUTPUT_DIR --output_path $OUTPUT_DIR

examples/llm_qat/utils.py

Lines changed: 2 additions & 1 deletion
@@ -171,5 +171,6 @@ def new_func(original_f_name, trainer, *args, **kwargs):
 
 def get_metrics_with_perplexity(metrics):
     """Add perplexity to the metrics."""
-    metrics = {"perplexity": float(torch.exp(torch.tensor(metrics["eval_loss"]))), **metrics}
+    if "eval_loss" in metrics:
+        metrics["perplexity"] = float(torch.exp(torch.tensor(metrics["eval_loss"])))
     return metrics
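
The guard added here matters because metrics from train-only logging carry no eval_loss key; a quick sketch of the guarded computation (sample values made up for illustration):

import torch

def get_metrics_with_perplexity(metrics):
    """Add perplexity to the metrics."""
    if "eval_loss" in metrics:
        metrics["perplexity"] = float(torch.exp(torch.tensor(metrics["eval_loss"])))
    return metrics

print(get_metrics_with_perplexity({"eval_loss": 2.0}))
# -> {'eval_loss': 2.0, 'perplexity': 7.389056...}  (perplexity = exp(eval_loss))
print(get_metrics_with_perplexity({"train_loss": 1.5}))
# -> {'train_loss': 1.5}  (passes through instead of raising KeyError)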

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 74 additions & 21 deletions
@@ -21,6 +21,8 @@
 from dataclasses import dataclass, field
 
 import torch
+import torch.distributed.checkpoint as dist_cp
+from accelerate.utils import save_fsdp_model
 from tqdm import tqdm
 
 import modelopt.torch.opt as mto
@@ -114,6 +116,15 @@ def check_awq_smoothquant(quant_cfg):
     return is_awq_smoothquant
 
 
+def restore_modelopt_state_with_weights(model, modelopt_state_path):
+    """Restore the modelopt weights for fsdp2 models."""
+    _modelopt_state = torch.load(modelopt_state_path, weights_only=False)
+    modelopt_weights = _modelopt_state.pop("modelopt_state_weights", None)
+    restore_from_modelopt_state(model, _modelopt_state)
+    if modelopt_weights is not None:
+        set_quantizer_state_dict(model, modelopt_weights)
+
+
 class QATTrainer(ModelOptHFTrainer):
     """A drop-in replacement of HuggingFace's Trainer for quantization aware training with ModelOpt.
 
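
Read together with _save_modelopt_state_with_weights below, this restore helper implies the following layout for the saved state file; the concrete contents are produced by the save path, so the values here are placeholders:

# Assumed structure of the dict loaded by restore_modelopt_state_with_weights:
_modelopt_state = {
    "modelopt_state_dict": [...],     # conversion states, minus kd_loss/export_student entries
    "modelopt_state_weights": {...},  # quantizer tensors from get_quantizer_state_dict(model)
}
# The helper pops "modelopt_state_weights" first, so the remaining dict is
# exactly what restore_from_modelopt_state(model, ...) expects.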
@@ -165,10 +176,12 @@ def __init__(
         self._modelopt_state_path = os.path.join(self.args.output_dir, "modelopt_state_train.pth")
         if os.path.exists(self._modelopt_state_path):
             self._restore_modelopt_state_with_weights()
-            print_rank_0("Restored modelopt state with weights.")
+        elif is_quantized(self.model):
+            self._save_modelopt_state_with_weights()
 
     def _save_modelopt_state_with_weights(self):
         """Save the modelopt weights for fsdp2 models."""
+        print_rank_0(f"Saving modelopt state to {self._modelopt_state_path}")
         if torch.distributed.is_initialized():
             torch.distributed.barrier()
 
@@ -179,18 +192,13 @@ def _save_modelopt_state_with_weights(self):
             for state in modelopt_state["modelopt_state_dict"]
             if "kd_loss" not in state and "export_student" not in state
         ]
-        modelopt_full_state = {
-            "modelopt_state": modelopt_state,
-            "modelopt_state_weights": get_quantizer_state_dict(self.model),
-        }
-
+        modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
         if self.args.should_save:
-            torch.save(modelopt_full_state, self._modelopt_state_path)
+            torch.save(modelopt_state, self._modelopt_state_path)
 
     def _restore_modelopt_state_with_weights(self):
-        modelopt_full_state = torch.load(self._modelopt_state_path, weights_only=False)
-        restore_from_modelopt_state(self.model, modelopt_full_state["modelopt_state"])
-        set_quantizer_state_dict(self.model, modelopt_full_state["modelopt_state_weights"])
+        restore_modelopt_state_with_weights(self.model, self._modelopt_state_path)
+        print_rank_0("Restored modelopt state with weights.")
 
     def _quantize_model(self):
         """Quantize the model. Restore the quantization state if it exists."""
@@ -219,7 +227,6 @@ def forward_loop(model):
         # Force garbage collection to free up memory
         gc.collect()
 
-        print_rank_0(f"Saving modelopt state to {self._modelopt_state_path}")
         self._save_modelopt_state_with_weights()
         torch.cuda.empty_cache()
 
@@ -247,17 +254,29 @@ def evaluate(self, *args, **kwargs):
             self.model, _ = self.accelerator.prepare(self.model, dummy_optimizer)
         return super().evaluate(*args, **kwargs)
 
-    def save_model(self, *args, **kwargs):
+    def save_model(
+        self, output_dir: str | None = None, _internal_call: bool = False, *args, **kwargs
+    ):
         """Save the quantized model."""
-        if (
-            (not self.is_in_train)
-            and self.is_fsdp_enabled
-            and self.accelerator.state.fsdp_plugin.state_dict_type != "FULL_STATE_DICT"
-        ):
-            print_rank_0("Setting state_dict_type to FULL_STATE_DICT for final checkpoint save.")
-            # TODO: test is this fix works for multi-node training
-            self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-        return super().save_model(*args, **kwargs)
+        dict_type = (
+            str(self.accelerator.state.fsdp_plugin.state_dict_type) if self.is_fsdp_enabled else ""
+        )
+        if not _internal_call and self.is_fsdp_enabled and "SHARDED_STATE_DICT" in dict_type:
+            # The default save_model in Trainer doesn't save checkpoint with SHARDED_STATE_DICT + FSDP.
+            # We save the model manually at the end of the training in order to convert the last
+            # checkpoint from distcp to HF compatible format.
+            if output_dir is None:
+                output_dir = self.args.output_dir
+            save_fsdp_model(
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                self.model,
+                output_dir,
+            )
+            self.processing_class.save_pretrained(output_dir)
+            self.model.config.save_pretrained(output_dir)
+        else:
+            super().save_model(output_dir, _internal_call, *args, **kwargs)
 
     def _patch_accelerate_for_fsdp2_fix(self):
         """Fixes for accelerate prepare.
@@ -360,3 +379,37 @@ def save_model(
         return KDTrainer.save_model(
             self, output_dir, _internal_call, export_student, *args, **kwargs
         )
+
+
+def convert_sharded_model_to_hf_format(
+    model, model_path, modelopt_state_name="modelopt_state.pth", output_path=None
+):
+    """Convert a sharded model to HF format.
+
+    Args:
+        model: The original HF model.
+        model_path: The path to the sharded model with pytorch_model_fsdp_0 directory.
+        modelopt_state_name: The name of the modelopt state file. If not provided, the default name
+            "modelopt_state.pth" will be used.
+        output_path: The path to save the converted model. If not provided, the model will be saved
+            to the same directory as the sharded model.
+    """
+    if output_path is None:
+        output_path = model_path
+    os.makedirs(output_path, exist_ok=True)
+    state_dict = {"model": model.state_dict()}
+    sharded_model_path = os.path.join(model_path, "pytorch_model_fsdp_0")
+    modelopt_state_path = os.path.join(model_path, modelopt_state_name)
+    if not os.path.exists(sharded_model_path):
+        print_rank_0(f"Sharded model path does not exist: {sharded_model_path}")
+        return model
+    dist_cp.load_state_dict(
+        state_dict=state_dict,
+        storage_reader=dist_cp.FileSystemReader(sharded_model_path),
+        no_dist=True,
+    )
+    model.load_state_dict(state_dict["model"])
+    restore_modelopt_state_with_weights(model, modelopt_state_path)
+    mto.enable_huggingface_checkpointing()
+    model.save_pretrained(output_path)
+    return model
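
For reference, a minimal offline use of the converter, equivalent to what the new convert_sharded_ckpt.py entry point does (model and directory paths are placeholders):

from transformers import AutoModelForCausalLM

from modelopt.torch.quantization.plugins.transformers_trainer import (
    convert_sharded_model_to_hf_format,
)

model = AutoModelForCausalLM.from_pretrained("path/to/original_hf_model")
convert_sharded_model_to_hf_format(
    model,
    model_path="path/to/qat_output_dir",  # must contain pytorch_model_fsdp_0/
    modelopt_state_name="modelopt_state_train.pth",
    output_path="path/to/converted_hf_model",  # defaults to model_path when omitted
)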
