Commit 263b2b7

Moved vllm fq export code to separate files (#612)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Moved the vLLM fakequant checkpoint export code to separate files:

1. For HF export -> `modelopt.torch.export.plugins.vllm_fakequant_hf`
2. For Megatron export -> `modelopt.torch.export.plugins.vllm_fakequant_megatron`

## Usage

Refer to [README.md](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/096ee13ea62bbb0ce0a4e4128c439651374d6235/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip).

## Testing

- Tested the HF path by exporting a bf16 model with the QAT script and running the vLLM server; verified that the amax values match.
- Tested the MCore path by quantizing and exporting a bf16 model with the quantize.sh and export.sh scripts and running the vLLM server; verified that the amax values match.
- Tested with the unit tests in `tests/gpu/torch/export/test_vllm_fq_hf_export.py` and `tests/gpu/torch/export/test_vllm_fq_megatron_export.py`.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?** Yes
- **Did you write any new necessary tests?** NA
- **Did you add or update any necessary documentation?** Yes
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?** NA

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added dedicated export functions for the vLLM fakequant checkpoint format, supporting both HuggingFace and Megatron Core models.
* **Refactor**
  * Simplified the export API by removing conditional export flags for cleaner, more predictable behavior.
  * Reorganized export functionality into focused plugin modules for improved maintainability.

---------

Signed-off-by: Kinjal Patel <[email protected]>
1 parent 0a4f0a8 commit 263b2b7

10 files changed, +355 -310 lines changed

examples/vllm_serve/README.md

Lines changed: 3 additions & 3 deletions
@@ -57,10 +57,10 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,
 
 Overwrite the calibrated amax value with prepared values from either QAT/PTQ.
 
-Step 1: export the model with bf16 weights and amax values.
+Step 1: export the model with bf16 weights and amax values. To export the model:
 
-- For HF model set `export_bf16_weights_amax` to export the model with function `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
-- For MCore model use `export_bf16_weights_amax` to export the model with function `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
+- For HF model use `modelopt.torch.export.export_hf_vllm_fq_checkpoint` function.
+- For MCore model use `modelopt.torch.export.export_mcore_gpt_to_hf_vllm_fq` function.
 
 Step 2: configure <quant_amax.pth> from exported model using AMAX_FILE_PATH environment variable in step 1. For example:
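For orientation only, here is a hedged sketch of what Step 1 could look like for an HF model, using the new function referenced above (the model id, calibration step, and export directory are placeholders, not part of this commit):

```python
# Sketch of Step 1 for an HF model; the model id and paths are hypothetical placeholders.
from transformers import AutoModelForCausalLM

from modelopt.torch.export import export_hf_vllm_fq_checkpoint

model = AutoModelForCausalLM.from_pretrained("my-org/my-qat-model")  # placeholder checkpoint
# ... the model is assumed to already carry ModelOpt fake-quantizers (from QAT or PTQ) ...

# Saves the weights in their original dtype via save_pretrained() and writes
# quant_amax.pth with the calibrated amax values.
export_hf_vllm_fq_checkpoint(model, "exported_model")  # placeholder export dir
```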

modelopt/torch/export/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from .model_config import *
 from .model_config_export import *
 from .model_utils import *
+from .plugins import *
 from .transformer_engine import *
 from .unified_export_hf import *
 from .unified_export_megatron import *

modelopt/torch/export/plugins/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -21,3 +21,7 @@
 from .megatron_importer import *
 
 from .hf_spec_export import *
+from .vllm_fakequant_hf import *
+
+with import_plugin("vllm_fakequant_megatron"):
+    from .vllm_fakequant_megatron import *

modelopt/torch/export/plugins/vllm_fakequant.py

Lines changed: 0 additions & 125 deletions
This file was deleted.
modelopt/torch/export/plugins/vllm_fakequant_hf.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export HuggingFace model to vLLM fakequant checkpoint."""

from pathlib import Path

import torch
import torch.nn as nn

from modelopt.torch.export.layer_utils import is_quantlinear
from modelopt.torch.quantization.utils import get_quantizer_state_dict

__all__ = ["export_hf_vllm_fq_checkpoint"]


def export_hf_vllm_fq_checkpoint(
    model: nn.Module,
    export_dir: Path | str,
):
    """Exports the torch model weights and amax values separately.

    This function:
    1. Extracts amax values for calibration
    2. Deletes all quantizer parameters from state dict to store only weights in original dtype
    3. Saves the model weights

    Args:
        model: The quantized model to export
        export_dir: Directory to save the amax values

    """
    export_dir = Path(export_dir)
    export_dir.mkdir(parents=True, exist_ok=True)

    amax_dict = {
        name + "._amax": param["_amax"].detach().clone().cpu()
        for name, param in get_quantizer_state_dict(model).items()
        if "_amax" in param
    }

    # remove quantizer from model
    for _, module in model.named_modules():
        if is_quantlinear(module):
            for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]:
                if hasattr(module, attr):
                    delattr(module, attr)
            module.export()
    torch.save(amax_dict, f"{export_dir}/quant_amax.pth")
    # Save model
    model.save_pretrained(export_dir, state_dict=model.state_dict(), save_modelopt_state=False)
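For context, a minimal sketch of how the artifacts written by `export_hf_vllm_fq_checkpoint` could be inspected (the export directory is a placeholder; the key format follows the dict comprehension in the code above):

```python
# Inspect the amax side-channel produced by export_hf_vllm_fq_checkpoint (placeholder path).
import torch

amax_dict = torch.load("exported_model/quant_amax.pth", map_location="cpu")

# Keys follow "<quantizer name>._amax", matching the dict comprehension in the export function.
for name, amax in list(amax_dict.items())[:5]:
    print(name, tuple(amax.shape))
```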
modelopt/torch/export/plugins/vllm_fakequant_megatron.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export Megatron Core Model to HuggingFace vLLM fakequant checkpoint."""

import os
import tempfile
from pathlib import Path

import torch

from modelopt.torch.export.model_config import QUANTIZATION_NONE
from modelopt.torch.export.unified_export_megatron import GPTModelExporter

__all__ = ["export_mcore_gpt_to_hf_vllm_fq"]


def gather_mcore_vllm_fq_quantized_state_dict(
    model, state_dict: dict[str, torch.Tensor], save_directory: str | os.PathLike
):
    """Gather all quantized state dict from all ranks and save them to a file.

    Args:
        state_dict: The state dictionary of the module.
        save_directory: The directory to save the quantized state dict.

    Returns:
        The state dictionary of the module without quantized state.
    """
    amax_state_dict = {
        k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax")
    }

    # Gather all amax dicts to rank 0
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    if rank == 0:
        # Rank 0 will collect all amax values
        all_amax_dicts = [None] * world_size
        torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0)

        # Merge all amax dicts into one
        merged_amax_dict = {}
        for amax_dict in all_amax_dicts:
            if amax_dict is not None:
                merged_amax_dict.update(amax_dict)

        print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}")
        torch.save(merged_amax_dict, save_directory + "/quant_amax.pth")
    else:
        # Other ranks just send their amax values
        torch.distributed.gather_object(amax_state_dict, None, dst=0)

    torch.distributed.barrier()


class VllmFqGPTModelExporter(GPTModelExporter):
    """VLLM fakequant GPTModel exporter."""

    def save_pretrained(
        self,
        save_directory: str | os.PathLike,
        pretrained_model_name_or_path: str | os.PathLike | None = None,
    ):
        os.makedirs(save_directory, exist_ok=True)
        gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory)
        assert not (self.is_multimodal and pretrained_model_name_or_path is not None), (
            "Exporting weights in bf16 and amax values is not supported for multimodal models "
            "when pretrained_model_name_or_path is not None"
        )
        assert not self.export_extra_modules, (
            "Exporting extra modules is not supported for vLLM fakequant"
        )
        super().save_pretrained(save_directory, pretrained_model_name_or_path)

    def _get_quantization_format(self, module: torch.nn.Module):
        return QUANTIZATION_NONE


def export_mcore_gpt_to_hf_vllm_fq(
    model: torch.nn.Module,
    pretrained_model_name_or_path: str | os.PathLike | None = None,
    export_extra_modules: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    export_dir: Path | str = tempfile.gettempdir(),
    moe_router_dtype: torch.dtype | None = None,
):
    """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.

    Args:
        model: The Megatron Core GPTModel instance.
        pretrained_model_name_or_path: Can be either: the *model id* of a
            pretrained model hosted inside a model repo on huggingface.co; or
            a *directory* containing model weights saved using
            [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
        export_extra_modules: If True, export extra modules like medusa_heads or
            eagle_module. Otherwise, only export the base model.
        dtype: The weights data type to export the unquantized layers.
        export_dir: The target export path.
    """
    exporter = VllmFqGPTModelExporter(
        model,
        pretrained_model_name_or_path,
        export_extra_modules=export_extra_modules,
        dtype=dtype,
        moe_router_dtype=moe_router_dtype,
    )
    exporter.save_pretrained(export_dir, pretrained_model_name_or_path)
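A hedged usage sketch for the MCore path (it assumes a quantized Megatron Core GPTModel and an initialized torch.distributed process group; the HF model id, function name, and export directory below are placeholders):

```python
# Sketch: export a quantized MCore GPTModel for vLLM fakequant serving (hypothetical names).
import torch

from modelopt.torch.export import export_mcore_gpt_to_hf_vllm_fq


def export_for_vllm_fakequant(mcore_model: torch.nn.Module, export_dir: str) -> None:
    # Every rank calls this; rank 0 gathers the amax values from all ranks and writes quant_amax.pth.
    export_mcore_gpt_to_hf_vllm_fq(
        mcore_model,
        pretrained_model_name_or_path="my-org/base-hf-model",  # placeholder HF id or local dir
        dtype=torch.bfloat16,
        export_dir=export_dir,  # placeholder path
    )
```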

modelopt/torch/export/unified_export_hf.py

Lines changed: 1 addition & 9 deletions
@@ -59,7 +59,6 @@
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
-from .plugins.vllm_fakequant import export_hf_vllm_fq_checkpoint
 from .quant_utils import (
     fuse_prequant_layernorm,
     fuse_prequant_to_linear,
@@ -559,7 +558,6 @@ def export_hf_checkpoint(
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
-    export_vllm_fq_weights_qstate: bool = False,
 ):
     """Exports the torch model to unified checkpoint and saves to export_dir.
 
@@ -568,8 +566,6 @@ def export_hf_checkpoint(
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
-        export_vllm_fq_weights_qstate: whether to export the weights and quantization state separately for vLLM
-            fakequant serving.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -583,11 +579,7 @@
         return
 
     try:
-        if export_vllm_fq_weights_qstate:
-            post_state_dict = export_hf_vllm_fq_checkpoint(model, export_dir)
-            hf_quant_config = None
-        else:
-            post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
 
         if hf_quant_config is not None:
             # Save hf_quant_config.json for backward compatibility