Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion examples/vllm_serve/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@ docker build -f examples/vllm_serve/Dockerfile -t vllm-modelopt .

## Calibrate and serve fake quant model in vLLM

Step 1: Modify `quant_config` in `vllm_serve_fake_quant.py` for the desired quantization format
Step 1: Configure quantization settings.
You can either edit the `quant_config` dictionary in `vllm_serve_fakequant.py`, or set the following environment variables to control quantization behavior:

| Variable | Description | Default |
|-----------------|--------------------------------------------------|---------------------|
| QUANT_DATASET | Dataset name for calibration | cnn_dailymail |
| QUANT_NUM_SAMPLES| Number of samples used for calibration | 512 |
| QUANT_FORMAT | Quantization format | NVFP4_DEFAULT_CFG |
| AMAX_FILE_PATH | Optional path to amax file (for loading amax) | None |

Set these variables in your shell or Docker environment as needed to customize calibration.

Step 2: Run the following command, with all supported flag as `vllm serve`:

Expand Down Expand Up @@ -58,3 +68,4 @@ Step 2: add `<vllm_amax.pth>` to `quant_config` in `vllm_serve_fakequant.py`
## Know Problems

1. AWQ is not yet supported in vLLM.
2. Amax sync across TP/EP is not handled now.
221 changes: 221 additions & 0 deletions examples/vllm_serve/fakequant_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import os
import warnings
from contextlib import contextmanager
from typing import Any

import torch
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.worker.gpu_worker import Worker as BaseWorker

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader


@contextmanager
def disable_compilation(model):
do_not_compile = True
if hasattr(model, "model"):
do_not_compile = model.model.do_not_compile
model.model.do_not_compile = True
elif hasattr(model, "language_model"):
do_not_compile = model.language_model.model.do_not_compile
model.language_model.model.do_not_compile = True
else:
raise ValueError("Model does not have a model or language_model attribute")

try:
yield
finally:
if hasattr(model, "model"):
model.model.do_not_compile = do_not_compile
elif hasattr(model, "language_model"):
model.language_model.model.do_not_compile = do_not_compile


quant_config: dict[str, Any] = {
"quant_dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"),
"quant_num_samples": int(os.environ.get("QUANT_NUM_SAMPLES", 512)),
"quant_format": os.environ.get("QUANT_FORMAT", "NVFP4_DEFAULT_CFG"),
"amax_file_path": os.environ.get("AMAX_FILE_PATH", None),
}


def _create_new_data_cls(data_cls, **kwargs):
"""vLLM's low-level API changes frequently. This function creates a class with parameters
compatible with the different vLLM versions."""
valid_params = {field.name for field in dataclasses.fields(data_cls)}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
return data_cls(**filtered_kwargs)


def _fakequant_run_prolog_worker(self) -> None:
tokenizer = AutoTokenizer.from_pretrained(
self.model_runner.model_config.tokenizer,
trust_remote_code=True,
)
if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

if quant_config["amax_file_path"]:
print("Will load amax, so only do a single sample calibration")
quant_config["quant_num_samples"] = 1

calib_dataloader = get_dataset_dataloader(
dataset_name=quant_config["quant_dataset"],
tokenizer=tokenizer,
batch_size=1,
num_samples=quant_config["quant_num_samples"],
device=self.device,
)

def calibrate_loop(model: Any = None) -> None:
for batch_idx, batch in tqdm(enumerate(calib_dataloader)):
input_ids = batch["input_ids"][0]

# Convert tensor to list of integers for vLLM compatibility
if torch.is_tensor(input_ids):
input_ids_list = input_ids.cpu().tolist()
else:
input_ids_list = list(input_ids)

num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups)
empty_block_ids = tuple([] for _ in range(num_groups))

req_id = f"req-{batch_idx}"
# Pass all possible parameters - the helper will filter based on vLLM version
new_req = _create_new_data_cls(
NewRequestData,
req_id=req_id,
prompt_token_ids=input_ids_list,
# Old API parameters
mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated
mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated
mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated
# New API parameter
mm_features=[],
sampling_params=SamplingParams(max_tokens=1),
pooling_params=None,
block_ids=empty_block_ids,
num_computed_tokens=0,
lora_request=None,
)

scheduler_output = _create_new_data_cls(
SchedulerOutput,
scheduled_new_reqs=[new_req],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: len(input_ids_list)},
total_num_scheduled_tokens=len(input_ids_list),
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0] * num_groups,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
kv_connector_metadata=None,
# Old API parameters
structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated
grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated
)
output = self.execute_model(scheduler_output)
if hasattr(self, "sample_tokens"):
if output is None: # TODO: make this default when vllm <= 0.11 is outdated
self.sample_tokens(None)

quant_cfg = getattr(mtq, quant_config["quant_format"])

model = self.model_runner.model
if hasattr(model, "unwrap"):
model = model.unwrap()

with disable_compilation(model):
print("quantizing model...")
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

amax_file_path = quant_config["amax_file_path"]
if amax_file_path:
print(f"Loading amax values from {amax_file_path}")
saved_amax_dict = torch.load(amax_file_path)
current_state_dict = model.state_dict()

# Count amax keys in checkpoint and model
checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")]
model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")]
for key in checkpoint_amax_keys:
if key not in model_amax_keys:
print(f"Key {key} not found in model state dict, but exists in checkpoint")
for key in model_amax_keys:
if key not in checkpoint_amax_keys:
raise ValueError(
f"Key {key} not found in checkpoint state dict, but exists in model"
)

checkpoint_amax_count = len(checkpoint_amax_keys)
model_amax_count = len(model_amax_keys)

# Ensure counts match
if checkpoint_amax_count != model_amax_count:
warnings.warn(
f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP."
)

# Update amax values
for key, value in saved_amax_dict.items():
if key in current_state_dict:
current_state_dict[key] = value.to(current_state_dict[key].device)

model.load_state_dict(current_state_dict)
torch.distributed.barrier()

if amax_file_path is None:
# Sync amax across TP can be done here if needed
pass
# for name, buffer in model.named_buffers():
# if name.endswith("_amax"):
# print("syncing amax across TP for", name)
# torch.distributed.all_reduce(
# buffer, op=torch.distributed.ReduceOp.MAX, group=get_tp_group().device_group
# )
# torch.distributed.barrier()

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
mtq.print_quant_summary(model)

mtq.fold_weight(model)
for name, module in model.named_modules():
if name.endswith("weight_quantizer"):
assert not module.is_enabled, f"quantizer {name} is still enabled"
Comment on lines +186 to +206
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do this under disable_compilation context?



class FakeQuantWorker(BaseWorker):
@torch.inference_mode()
def determine_available_memory(self) -> int:
model = self.model_runner.model
if hasattr(model, "unwrap"):
model = model.unwrap()
with disable_compilation(model):
return super().determine_available_memory()

def compile_or_warm_up_model(self) -> None:
if quant_config["quant_format"]:
_fakequant_run_prolog_worker(self)
super().compile_or_warm_up_model()
Loading
Loading