
Commit 72e2167

Fix/Improve vllm PTQ and Support multi-node with ray (#534)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fix and improve vLLM PTQ:

1. Support Ray, so calibration can run on multiple nodes.
2. Fix a MoE typo and improve weight folding for large MoE layers.
3. Add the `SharedFusedMoE` layer.
4. Support vLLM > 0.11 (not released yet).
5. Add OS environment variables to specify quant configs.

## Testing

Tested with the latest vLLM.

## Additional Information

vLLM > 0.11.0 changed the low-level API significantly. Some of these changes need to be removed once vLLM <= 0.11.0 is outdated.

---------

Signed-off-by: mxin <[email protected]>
1 parent a113bea commit 72e2167
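As background for item 4 and the "Additional Information" note above: the new worker file in this commit copes with vLLM's shifting low-level API by filtering dataclass keyword arguments against whatever fields the installed version defines (see `_create_new_data_cls` in the diff below). A minimal standalone sketch of that pattern; the function name here is illustrative, not part of the diff:

```python
import dataclasses


def build_for_installed_version(data_cls, **kwargs):
    """Instantiate a vLLM dataclass, dropping kwargs the installed version doesn't define."""
    # Collect the field names this vLLM version's dataclass actually declares.
    valid = {field.name for field in dataclasses.fields(data_cls)}
    # Old- and new-API parameters can be passed together; unknown ones are discarded.
    return data_cls(**{k: v for k, v in kwargs.items() if k in valid})
```

This is why the worker can pass both the pre-0.11 parameters (`mm_kwargs`, `mm_hashes`, `mm_positions`) and the newer `mm_features` in a single call.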

File tree: 4 files changed (+327, -189 lines)

examples/vllm_serve/README.md

Lines changed: 29 additions & 2 deletions
@@ -16,7 +16,17 @@ docker build -f examples/vllm_serve/Dockerfile -t vllm-modelopt .
 
 ## Calibrate and serve fake quant model in vLLM
 
-Step 1: Modify `quant_config` in `vllm_serve_fake_quant.py` for the desired quantization format
+Step 1: Configure quantization settings.
+You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, or set the following environment variables to control quantization behavior:
+
+| Variable         | Description                                    | Default           |
+|------------------|------------------------------------------------|-------------------|
+| QUANT_DATASET    | Dataset name for calibration                   | cnn_dailymail     |
+| QUANT_CALIB_SIZE | Number of samples used for calibration         | 512               |
+| QUANT_CFG        | Quantization format                            | NVFP4_DEFAULT_CFG |
+| AMAX_FILE_PATH   | Optional path to amax file (for loading amax)  | None              |
+
+Set these variables in your shell or Docker environment as needed to customize calibration.
 
 Step 2: Run the following command, which accepts all flags supported by `vllm serve`:
 
@@ -55,6 +65,23 @@ python convert_amax_hf2vllm.py -i <amax.pth> -o <vllm_amax.pth>
 
 Step 2: add `<vllm_amax.pth>` to `quant_config` in `vllm_serve_fakequant.py`
 
-## Know Problems
+## Important Notes
+
+**Amax Synchronization across Tensor Parallel (TP):**
+
+- **For non-per-tensor quantization**: It is **recommended** to use an amax file (via `AMAX_FILE_PATH`) because amax synchronization across TP/EP is not automatically handled. Without an amax file, the amax values can differ across TP ranks, leading to results that are inconsistent with real quantization.
+
+- **For per-tensor quantization**: If you are not using an amax file, you need to enable amax synchronization across TP ranks. An example implementation is provided in `fakequant_worker.py` (lines 190-198):
+
+```python
+for name, buffer in model.named_buffers():
+    if name.endswith("_amax"):
+        torch.distributed.all_reduce(
+            buffer, op=torch.distributed.ReduceOp.MAX, group=get_tp_group().device_group
+        )
+torch.distributed.barrier()
+```
+
+## Known Problems
 
 1. AWQ is not yet supported in vLLM.
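For reference, the environment variables documented in the table above map one-to-one onto the `quant_config` dictionary that the new worker (shown further down in this commit) builds at import time; a condensed sketch of that mapping:

```python
import os
from typing import Any

# Defaults match the table above; unset variables fall back to these values.
quant_config: dict[str, Any] = {
    "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
    "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
    "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"),
    "amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
}
```

Exporting, say, `QUANT_CFG=FP8_DEFAULT_CFG` before launching therefore switches the calibration format without editing the script; any config name exposed by `modelopt.torch.quantization` should work, since the worker resolves it with `getattr(mtq, ...)`.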
examples/vllm_serve/fakequant_worker.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import os
import warnings
from contextlib import contextmanager
from typing import Any

import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.worker.gpu_worker import Worker as BaseWorker

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader


@contextmanager
def disable_compilation(model):
    do_not_compile = True
    if hasattr(model, "model"):
        do_not_compile = model.model.do_not_compile
        model.model.do_not_compile = True
    elif hasattr(model, "language_model"):
        do_not_compile = model.language_model.model.do_not_compile
        model.language_model.model.do_not_compile = True
    else:
        raise ValueError("Model does not have a model or language_model attribute")

    try:
        yield
    finally:
        if hasattr(model, "model"):
            model.model.do_not_compile = do_not_compile
        elif hasattr(model, "language_model"):
            model.language_model.model.do_not_compile = do_not_compile


quant_config: dict[str, Any] = {
    "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
    "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)),
    "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"),
    "amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
}


def _create_new_data_cls(data_cls, **kwargs):
    """vLLM's low-level API changes frequently. This function creates a class with parameters
    compatible with the different vLLM versions."""
    valid_params = {field.name for field in dataclasses.fields(data_cls)}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    return data_cls(**filtered_kwargs)


def _fakequant_run_prolog_worker(self) -> None:
    tokenizer = AutoTokenizer.from_pretrained(
        self.model_runner.model_config.tokenizer,
        trust_remote_code=True,
    )
    if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if quant_config["amax_file_path"]:
        print("Will load amax, so only do a single sample calibration")
        quant_config["calib_size"] = 1

    calib_dataloader = get_dataset_dataloader(
        dataset_name=quant_config["dataset"],
        tokenizer=tokenizer,
        batch_size=1,
        num_samples=quant_config["calib_size"],
        device=self.device,
    )

    def calibrate_loop(model: Any = None) -> None:
        for batch_idx, batch in tqdm(enumerate(calib_dataloader)):
            input_ids = batch["input_ids"][0]

            # Convert tensor to list of integers for vLLM compatibility
            if torch.is_tensor(input_ids):
                input_ids_list = input_ids.cpu().tolist()
            else:
                input_ids_list = list(input_ids)

            num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups)
            empty_block_ids = tuple([] for _ in range(num_groups))

            req_id = f"req-{batch_idx}"
            # Pass all possible parameters - the helper will filter based on vLLM version
            new_req = _create_new_data_cls(
                NewRequestData,
                req_id=req_id,
                prompt_token_ids=input_ids_list,
                # Old API parameters
                mm_kwargs=[],  # TODO: remove this when vllm <= 0.11 is outdated
                mm_hashes=[],  # TODO: remove this when vllm <= 0.11 is outdated
                mm_positions=[],  # TODO: remove this when vllm <= 0.11 is outdated
                # New API parameter
                mm_features=[],
                sampling_params=SamplingParams(max_tokens=1),
                pooling_params=None,
                block_ids=empty_block_ids,
                num_computed_tokens=0,
                lora_request=None,
            )

            scheduler_output = _create_new_data_cls(
                SchedulerOutput,
                scheduled_new_reqs=[new_req],
                scheduled_cached_reqs=CachedRequestData.make_empty(),
                num_scheduled_tokens={req_id: len(input_ids_list)},
                total_num_scheduled_tokens=len(input_ids_list),
                scheduled_spec_decode_tokens={},
                scheduled_encoder_inputs={},
                num_common_prefix_blocks=[0] * num_groups,
                finished_req_ids=set(),
                free_encoder_mm_hashes=[],
                kv_connector_metadata=None,
                # Old API parameters
                structured_output_request_ids={},  # TODO: remove this when vllm <= 0.11 is outdated
                grammar_bitmask=None,  # TODO: remove this when vllm <= 0.11 is outdated
            )
            output = self.execute_model(scheduler_output)
            if hasattr(self, "sample_tokens"):
                if output is None:  # TODO: make this default when vllm <= 0.11 is outdated
                    self.sample_tokens(None)

    quant_cfg = getattr(mtq, quant_config["quant_cfg"])

    model = self.model_runner.model
    if hasattr(model, "unwrap"):
        model = model.unwrap()

    with disable_compilation(model):
        print("quantizing model...")
        mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

    amax_file_path = quant_config["amax_file_path"]
    if amax_file_path:
        print(f"Loading amax values from {amax_file_path}")
        saved_amax_dict = torch.load(amax_file_path)
        current_state_dict = model.state_dict()

        # Count amax keys in checkpoint and model
        checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")]
        model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")]
        for key in checkpoint_amax_keys:
            if key not in model_amax_keys:
                print(f"Key {key} not found in model state dict, but exists in checkpoint")
        for key in model_amax_keys:
            if key not in checkpoint_amax_keys:
                raise ValueError(
                    f"Key {key} not found in checkpoint state dict, but exists in model"
                )

        checkpoint_amax_count = len(checkpoint_amax_keys)
        model_amax_count = len(model_amax_keys)

        # Ensure counts match
        if checkpoint_amax_count != model_amax_count:
            warnings.warn(
                f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
                f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP."
            )

        # Update amax values
        for key, value in saved_amax_dict.items():
            if key in current_state_dict:
                current_state_dict[key] = value.to(current_state_dict[key].device)

        model.load_state_dict(current_state_dict)
        torch.distributed.barrier()

    if amax_file_path is None:
        # Sync amax across TP can be done here if needed
        pass
        # for name, buffer in model.named_buffers():
        #     if name.endswith("_amax"):
        #         print("syncing amax across TP for", name)
        #         torch.distributed.all_reduce(
        #             buffer, op=torch.distributed.ReduceOp.MAX, group=get_tp_group().device_group
        #         )
        # torch.distributed.barrier()

    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        mtq.print_quant_summary(model)

    mtq.fold_weight(model)
    for name, module in model.named_modules():
        if name.endswith("weight_quantizer"):
            assert not module.is_enabled, f"quantizer {name} is still enabled"


class FakeQuantWorker(BaseWorker):
    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        model = self.model_runner.model
        if hasattr(model, "unwrap"):
            model = model.unwrap()
        with disable_compilation(model):
            return super().determine_available_memory()

    def compile_or_warm_up_model(self) -> None:
        if quant_config["quant_cfg"]:
            _fakequant_run_prolog_worker(self)
        super().compile_or_warm_up_model()
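One way the `FakeQuantWorker` above might be wired into vLLM from Python — a sketch only, assuming the installed vLLM exposes the `worker_cls` engine argument for custom worker classes and that `fakequant_worker.py` is importable on `PYTHONPATH`; the launch flow in `vllm_serve_fakequant.py` is the authoritative path:

```python
import os

from vllm import LLM, SamplingParams

# Assumption: these variables are read into the worker's quant_config at import time.
os.environ.setdefault("QUANT_CFG", "NVFP4_DEFAULT_CFG")
os.environ.setdefault("QUANT_CALIB_SIZE", "512")

# Assumption: `worker_cls` accepts a dotted path to a Worker subclass; the module
# path below is hypothetical and depends on where fakequant_worker.py lives.
llm = LLM(
    model="<your-hf-model>",
    worker_cls="fakequant_worker.FakeQuantWorker",
)

out = llm.generate(["Hello from a fake-quant model:"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```

With this wiring, calibration runs once inside `compile_or_warm_up_model()` before the engine finishes starting up, so the served model is already fake-quantized.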
