
Commit 6d69e05

Support fakequant serve with latest vllm and generalize calibration logic

1 parent 26c203a

4 files changed: +567 -0 lines changed

examples/vllm_serve/Dockerfile

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
FROM vllm/vllm-openai:v0.10.2

# Set environment variables
ENV PIP_NO_CACHE_DIR=off \
    PIP_CONSTRAINT=

WORKDIR /workspace

# Install system dependencies needed for modelopt
RUN apt-get update && apt-get install -y \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy the entire TensorRT-Model-Optimizer source code
COPY . TensorRT-Model-Optimizer

# Remove .git directory to reduce image size
RUN rm -rf TensorRT-Model-Optimizer/.git

# Install modelopt from local source with all dependencies
RUN cd TensorRT-Model-Optimizer && \
    pip install -e ".[all,dev-test]"

# Llama4 requires this
RUN pip install flash-attn==2.7.4.post1

# Pre-compile CUDA extensions to avoid compilation time during runtime
RUN python3 -c "import modelopt.torch.quantization.extensions as ext; ext.precompile()" || true

# Install requirements from examples (excluding windows examples)
RUN find TensorRT-Model-Optimizer/examples -name "requirements.txt" | grep -v "windows" | while read req_file; do \
    echo "Installing from $req_file"; \
    pip install -r "$req_file" || echo "Warning: Failed to install from $req_file"; \
    done

# Allow users to run without root
RUN chmod -R 777 /workspace

# Override the ENTRYPOINT from the base image to allow flexible usage
ENTRYPOINT []

# Set the default command
CMD ["/bin/bash"]

examples/vllm_serve/README.md

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# Serve fakequant models with vLLM

This is a simple example that demonstrates calibrating and serving ModelOpt fakequant models in vLLM.

Compared with realquant, fakequant is 2-5x slower, but it doesn't require dedicated kernel support and facilitates research.

This example is tested with vLLM 0.9.0 and 0.11.2.

## Prepare environment

Follow the instructions below to build a Docker environment, or install vLLM with pip.

```bash
docker build -f examples/vllm_serve/Dockerfile -t vllm-modelopt .
```

## Calibrate and serve a fakequant model in vLLM

Step 1: Modify `quant_config` in `vllm_serve_fakequant.py` to select the desired quantization format.
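
A minimal sketch of what this selection might look like, assuming `quant_config` points at one of ModelOpt's predefined quantization configs (the exact structure used in `vllm_serve_fakequant.py` may differ; the config names below are from `modelopt.torch.quantization`):

```python
# Illustrative only: the real quant_config in vllm_serve_fakequant.py may carry
# extra fields (e.g. calibration dataset or sample count); this just shows how a
# predefined ModelOpt quantization format could be selected.
import modelopt.torch.quantization as mtq

quant_config = mtq.FP8_DEFAULT_CFG  # e.g. swap in mtq.NVFP4_DEFAULT_CFG for NVFP4
```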

Step 2: Run the following command; all flags supported by `vllm serve` can be used:

```bash
python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
```

Step 3: Test the API server with curl:

```bash
curl -X POST "http://127.0.0.1:8000/v1/chat/completions" -H "Content-Type: application/json" -d '{
    "model": "<model_path>",
    "messages": [
        {"role": "user", "content": "Hi, what is your name"}
    ],
    "max_tokens": 8
}'
```

Step 4 (optional): Use lm_eval to run an evaluation:

```bash
lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False,batch_size=128,tokenizer_backend=None
```

## Load a QAT/PTQ model and serve in vLLM (WIP)

Overwrite the calibrated amax values with values prepared via PTQ or QAT. This has only been tested with Llama 3.1.

Step 1: Convert the HuggingFace-format amax values to merged vLLM-format amax values, using Llama 3.1 as an example:

```bash
python convert_amax_hf2vllm.py -i <amax.pth> -o <vllm_amax.pth>
```

Step 2: Add `<vllm_amax.pth>` to `quant_config` in `vllm_serve_fakequant.py`.
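
To sanity-check the converted checkpoint from Step 1 before wiring it into `quant_config`, here is a quick inspection sketch (assuming `vllm_amax.pth` is the path passed to `-o` above):

```python
# List the merged amax entries produced by convert_amax_hf2vllm.py.
import torch

vllm_amax = torch.load("vllm_amax.pth", map_location="cpu")
for name, value in sorted(vllm_amax.items()):
    if "_amax" in name:
        print(f"{name}: shape={tuple(value.shape)}")
```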
examples/vllm_serve/convert_amax_hf2vllm.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re
from collections import defaultdict

import torch


def convert_amax_hf2vllm(
    hf_state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """
    Convert amax values from HuggingFace format to vLLM format.

    This function merges:
    - q_proj, k_proj, v_proj amax values into qkv_proj (taking max)
    - gate_proj, up_proj amax values into gate_up_proj (taking max)

    Args:
        hf_state_dict: HuggingFace state dict containing amax values

    Returns:
        vLLM format state dict with merged amax values
    """
    vllm_state_dict = {}

    # Group keys by their base pattern (without the specific projection name)
    merge_groups = defaultdict(list)

    for key, value in hf_state_dict.items():
        if "_amax" not in key:
            # Copy non-amax keys as-is
            vllm_state_dict[key] = value
            continue

        # Check if this is a q/k/v projection that needs merging
        qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key)
        if qkv_match:
            base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3)
            merge_groups[base_pattern].append((key, value))
            continue

        # Check if this is a gate/up projection that needs merging
        gate_up_match = re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
        if gate_up_match:
            base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
            merge_groups[base_pattern].append((key, value))
            continue

        # Copy other amax keys as-is (like o_proj, down_proj)
        vllm_state_dict[key] = value

    # Merge grouped amax values by taking the maximum
    for merged_key, key_value_pairs in merge_groups.items():
        if len(key_value_pairs) > 1:
            # Take the maximum across all values for this merged key
            values = [value for _, value in key_value_pairs]
            merged_value = torch.stack(values).max(dim=0)[0]
            vllm_state_dict[merged_key] = merged_value
            print(f"Merged {len(key_value_pairs)} keys into {merged_key}")
            for orig_key, _ in key_value_pairs:
                print(f"  - {orig_key}")
        else:
            # Single key, just rename it
            _, value = key_value_pairs[0]
            vllm_state_dict[merged_key] = value

    return vllm_state_dict


def test_conversion():
    """Test the conversion logic with sample keys"""
    # Create sample HF state dict
    sample_hf_keys = [
        "model.layers.0.self_attn.q_proj.input_quantizer._amax",
        "model.layers.0.self_attn.k_proj.input_quantizer._amax",
        "model.layers.0.self_attn.v_proj.input_quantizer._amax",
        "model.layers.0.self_attn.q_proj.weight_quantizer._amax",
        "model.layers.0.self_attn.k_proj.weight_quantizer._amax",
        "model.layers.0.self_attn.v_proj.weight_quantizer._amax",
        "model.layers.0.self_attn.o_proj.input_quantizer._amax",
        "model.layers.0.self_attn.o_proj.weight_quantizer._amax",
        "model.layers.0.mlp.gate_proj.input_quantizer._amax",
        "model.layers.0.mlp.up_proj.input_quantizer._amax",
        "model.layers.0.mlp.gate_proj.weight_quantizer._amax",
        "model.layers.0.mlp.up_proj.weight_quantizer._amax",
        "model.layers.0.mlp.down_proj.input_quantizer._amax",
        "model.layers.0.mlp.down_proj.weight_quantizer._amax",
    ]

    hf_state_dict = {}
    for key in sample_hf_keys:
        hf_state_dict[key] = torch.tensor([1.0, 2.0, 3.0])  # Sample values

    print("Testing conversion with sample keys...")
    print(f"Input keys: {len(sample_hf_keys)}")

    vllm_state_dict = convert_amax_hf2vllm(hf_state_dict)
    vllm_amax_keys = [k for k in vllm_state_dict if "_amax" in k]

    print(f"Output keys: {len(vllm_amax_keys)}")
    print("\nExpected vLLM keys:")
    expected_keys = [
        "model.layers.0.self_attn.qkv_proj.input_quantizer._amax",
        "model.layers.0.self_attn.qkv_proj.weight_quantizer._amax",
        "model.layers.0.self_attn.o_proj.input_quantizer._amax",
        "model.layers.0.self_attn.o_proj.weight_quantizer._amax",
        "model.layers.0.mlp.gate_up_proj.input_quantizer._amax",
        "model.layers.0.mlp.gate_up_proj.weight_quantizer._amax",
        "model.layers.0.mlp.down_proj.input_quantizer._amax",
        "model.layers.0.mlp.down_proj.weight_quantizer._amax",
    ]

    for key in expected_keys:
        print(f"  {key}")

    print("\nActual vLLM keys:")
    for key in sorted(vllm_amax_keys):
        print(f"  {key}")

    # Check if all expected keys are present
    missing_keys = set(expected_keys) - set(vllm_amax_keys)
    extra_keys = set(vllm_amax_keys) - set(expected_keys)

    if missing_keys:
        print(f"\nMissing keys: {missing_keys}")
    if extra_keys:
        print(f"\nExtra keys: {extra_keys}")

    if not missing_keys and not extra_keys:
        print("\n✓ Test passed! All keys converted correctly.")
    else:
        print("\n✗ Test failed! Key mismatch detected.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert amax values from HuggingFace to vLLM format"
    )
    parser.add_argument("--input", "-i", help="Input HuggingFace checkpoint path")
    parser.add_argument("--output", "-o", help="Output vLLM checkpoint path")
    parser.add_argument("--dry-run", action="store_true", help="Show conversion without saving")
    parser.add_argument("--test", action="store_true", help="Run test with sample data")

    args = parser.parse_args()

    if args.test:
        test_conversion()
        return

    if not args.input or not args.output:
        parser.error("--input and --output are required unless using --test")

    # Load HuggingFace checkpoint
    print(f"Loading HuggingFace checkpoint from: {args.input}")
    if os.path.isfile(args.input):
        hf_state_dict = torch.load(args.input, map_location="cpu")
    else:
        raise FileNotFoundError(f"File not found: {args.input}")

    print(f"Loaded {len(hf_state_dict)} keys from HuggingFace checkpoint")

    # Filter to only amax keys for analysis
    amax_keys = [k for k in hf_state_dict if "_amax" in k]
    print(f"Found {len(amax_keys)} amax keys")

    if args.dry_run:
        print("\nAmax keys in HuggingFace format:")
        for key in sorted(amax_keys):
            print(f"  {key}")

    # Convert to vLLM format
    print("\nConverting to vLLM format...")
    vllm_state_dict = convert_amax_hf2vllm(hf_state_dict)

    vllm_amax_keys = [k for k in vllm_state_dict if "_amax" in k]
    print(f"Result: {len(vllm_amax_keys)} amax keys in vLLM format")

    if args.dry_run:
        print("\nAmax keys in vLLM format:")
        for key in sorted(vllm_amax_keys):
            print(f"  {key}")
        print("\nDry run complete. No files saved.")
        return

    # Save vLLM checkpoint
    print(f"Saving vLLM checkpoint to: {args.output}")
    output_dir = os.path.dirname(args.output)
    if output_dir:  # only create a directory when the output path actually contains one
        os.makedirs(output_dir, exist_ok=True)
    torch.save(vllm_state_dict, args.output)
    print("Conversion complete!")


if __name__ == "__main__":
    main()
