diff --git a/examples/vllm_serve/Dockerfile b/examples/vllm_serve/Dockerfile
new file mode 100644
index 000000000..7fa28c5f1
--- /dev/null
+++ b/examples/vllm_serve/Dockerfile
@@ -0,0 +1,44 @@
+FROM vllm/vllm-openai:v0.10.2
+
+# Set environment variables
+ENV PIP_NO_CACHE_DIR=off \
+    PIP_CONSTRAINT=
+
+WORKDIR /workspace
+
+# Install system dependencies needed for modelopt
+RUN apt-get update && apt-get install -y \
+    git \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy the entire TensorRT-Model-Optimizer source code
+COPY . TensorRT-Model-Optimizer
+
+# Remove .git directory to reduce image size
+RUN rm -rf TensorRT-Model-Optimizer/.git
+
+# Install modelopt from local source with all dependencies
+RUN cd TensorRT-Model-Optimizer && \
+    pip install -e ".[all,dev-test]"
+
+# Llama4 requires this
+RUN pip install flash-attn==2.7.4.post1
+
+# Pre-compile CUDA extensions to avoid compilation time at runtime
+RUN python3 -c "import modelopt.torch.quantization.extensions as ext; ext.precompile()" || true
+
+# Install requirements from examples (excluding windows examples)
+RUN find TensorRT-Model-Optimizer/examples -name "requirements.txt" | grep -v "windows" | while read req_file; do \
+    echo "Installing from $req_file"; \
+    pip install -r "$req_file" || echo "Warning: Failed to install from $req_file"; \
+    done
+
+# Allow users to run without root
+RUN chmod -R 777 /workspace
+
+# Override the ENTRYPOINT from the base image to allow flexible usage
+ENTRYPOINT []
+
+# Set the default command
+CMD ["/bin/bash"]
diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md
new file mode 100644
index 000000000..8d5a5d35e
--- /dev/null
+++ b/examples/vllm_serve/README.md
@@ -0,0 +1,56 @@
+# Serve fakequant models with vLLM
+
+This is a simple example that demonstrates calibrating and serving ModelOpt fakequant models in vLLM.
+
+Compared with realquant, fakequant is 2-5x slower, but it doesn't require dedicated kernel support and it facilitates research.
+
+This example is tested with vLLM 0.9.0 and 0.11.2.
+
+## Prepare environment
+
+Follow the instructions below to build a Docker environment, or install vLLM with pip.
+
+```bash
+docker build -f examples/vllm_serve/Dockerfile -t vllm-modelopt .
+```
+
+## Calibrate and serve a fakequant model in vLLM
+
+Step 1: Modify `quant_config` in `vllm_serve_fakequant.py` for the desired quantization format.
+
+Step 2: Run the following command; it accepts all the flags supported by `vllm serve`:
+
+```bash
+python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000
+```
+
+Step 3: Test the API server with curl:
+
+```bash
+curl -X POST "http://127.0.0.1:8000/v1/chat/completions" -H "Content-Type: application/json" -d '{
+        "model": "",
+        "messages": [
+            {"role": "user", "content": "Hi, what is your name"}
+        ],
+        "max_tokens": 8
+    }'
+
+```
+
+Step 4 (Optional): Use lm_eval to run an evaluation:
+
+```bash
+lm_eval --model local-completions --tasks gsm8k --model_args model=,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False,batch_size=128,tokenizer_backend=None
+```
+
+## Load QAT/PTQ model and serve in vLLM (WIP)
+
+Overwrite the calibrated amax values with prepared values from PTQ or QAT. 
This is only tested for Llama3.1 + +Step 1: convert amax to merged amax, using llama3.1 as an example: + +```bash +python convert_amax_hf2vllm.py -i -o +``` + +Step 2: add `` to `quant_config` in `vllm_serve_fakequant.py` diff --git a/examples/vllm_serve/convert_amax_hf2vllm.py b/examples/vllm_serve/convert_amax_hf2vllm.py new file mode 100644 index 000000000..6f0321a91 --- /dev/null +++ b/examples/vllm_serve/convert_amax_hf2vllm.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 + +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re +from collections import defaultdict + +import torch + + +def convert_amax_hf2vllm( + hf_state_dict: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """ + Convert amax values from HuggingFace format to vLLM format. + + This function merges: + - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) + - gate_proj, up_proj amax values into gate_up_proj (taking max) + + Args: + hf_state_dict: HuggingFace state dict containing amax values + + Returns: + vLLM format state dict with merged amax values + """ + vllm_state_dict = {} + + # Group keys by their base pattern (without the specific projection name) + merge_groups = defaultdict(list) + + for key, value in hf_state_dict.items(): + if "_amax" not in key: + # Copy non-amax keys as-is + vllm_state_dict[key] = value + continue + + # Check if this is a q/k/v projection that needs merging + qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) + if qkv_match: + base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) + merge_groups[base_pattern].append((key, value)) + continue + + # Check if this is a gate/up projection that needs merging + gate_up_match = re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) + if gate_up_match: + base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) + merge_groups[base_pattern].append((key, value)) + continue + + # Copy other amax keys as-is (like o_proj, down_proj) + vllm_state_dict[key] = value + + # Merge grouped amax values by taking the maximum + for merged_key, key_value_pairs in merge_groups.items(): + if len(key_value_pairs) > 1: + # Take the maximum across all values for this merged key + values = [value for _, value in key_value_pairs] + merged_value = torch.stack(values).max(dim=0)[0] + vllm_state_dict[merged_key] = merged_value + print(f"Merged {len(key_value_pairs)} keys into {merged_key}") + for orig_key, _ in key_value_pairs: + print(f" - {orig_key}") + else: + # Single key, just rename it + _, value = key_value_pairs[0] + vllm_state_dict[merged_key] = value + + return vllm_state_dict + + +def test_conversion(): + """Test the conversion logic with sample keys""" + import torch + + # Create sample HF state dict + sample_hf_keys = [ + "model.layers.0.self_attn.q_proj.input_quantizer._amax", + 
"model.layers.0.self_attn.k_proj.input_quantizer._amax", + "model.layers.0.self_attn.v_proj.input_quantizer._amax", + "model.layers.0.self_attn.q_proj.weight_quantizer._amax", + "model.layers.0.self_attn.k_proj.weight_quantizer._amax", + "model.layers.0.self_attn.v_proj.weight_quantizer._amax", + "model.layers.0.self_attn.o_proj.input_quantizer._amax", + "model.layers.0.self_attn.o_proj.weight_quantizer._amax", + "model.layers.0.mlp.gate_proj.input_quantizer._amax", + "model.layers.0.mlp.up_proj.input_quantizer._amax", + "model.layers.0.mlp.gate_proj.weight_quantizer._amax", + "model.layers.0.mlp.up_proj.weight_quantizer._amax", + "model.layers.0.mlp.down_proj.input_quantizer._amax", + "model.layers.0.mlp.down_proj.weight_quantizer._amax", + ] + + hf_state_dict = {} + for key in sample_hf_keys: + hf_state_dict[key] = torch.tensor([1.0, 2.0, 3.0]) # Sample values + + print("Testing conversion with sample keys...") + print(f"Input keys: {len(sample_hf_keys)}") + + vllm_state_dict = convert_amax_hf2vllm(hf_state_dict) + vllm_amax_keys = [k for k in vllm_state_dict if "_amax" in k] + + print(f"Output keys: {len(vllm_amax_keys)}") + print("\nExpected vLLM keys:") + expected_keys = [ + "model.layers.0.self_attn.qkv_proj.input_quantizer._amax", + "model.layers.0.self_attn.qkv_proj.weight_quantizer._amax", + "model.layers.0.self_attn.o_proj.input_quantizer._amax", + "model.layers.0.self_attn.o_proj.weight_quantizer._amax", + "model.layers.0.mlp.gate_up_proj.input_quantizer._amax", + "model.layers.0.mlp.gate_up_proj.weight_quantizer._amax", + "model.layers.0.mlp.down_proj.input_quantizer._amax", + "model.layers.0.mlp.down_proj.weight_quantizer._amax", + ] + + for key in expected_keys: + print(f" {key}") + + print("\nActual vLLM keys:") + for key in sorted(vllm_amax_keys): + print(f" {key}") + + # Check if all expected keys are present + missing_keys = set(expected_keys) - set(vllm_amax_keys) + extra_keys = set(vllm_amax_keys) - set(expected_keys) + + if missing_keys: + print(f"\nMissing keys: {missing_keys}") + if extra_keys: + print(f"\nExtra keys: {extra_keys}") + + if not missing_keys and not extra_keys: + print("\n✓ Test passed! All keys converted correctly.") + else: + print("\n✗ Test failed! 
Key mismatch detected.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert amax values from HuggingFace to vLLM format" + ) + parser.add_argument("--input", "-i", help="Input HuggingFace checkpoint path") + parser.add_argument("--output", "-o", help="Output vLLM checkpoint path") + parser.add_argument("--dry-run", action="store_true", help="Show conversion without saving") + parser.add_argument("--test", action="store_true", help="Run test with sample data") + + args = parser.parse_args() + + if args.test: + test_conversion() + return + + if not args.input or not args.output: + parser.error("--input and --output are required unless using --test") + + # Load HuggingFace checkpoint + print(f"Loading HuggingFace checkpoint from: {args.input}") + if os.path.isfile(args.input): + hf_state_dict = torch.load(args.input, map_location="cpu") + else: + raise Exception(f"File not found: {args.input}") + + print(f"Loaded {len(hf_state_dict)} keys from HuggingFace checkpoint") + + # Filter to only amax keys for analysis + amax_keys = [k for k in hf_state_dict if "_amax" in k] + print(f"Found {len(amax_keys)} amax keys") + + if args.dry_run: + print("\nAmax keys in HuggingFace format:") + for key in sorted(amax_keys): + print(f" {key}") + + # Convert to vLLM format + print("\nConverting to vLLM format...") + vllm_state_dict = convert_amax_hf2vllm(hf_state_dict) + + vllm_amax_keys = [k for k in vllm_state_dict if "_amax" in k] + print(f"Result: {len(vllm_amax_keys)} amax keys in vLLM format") + + if args.dry_run: + print("\nAmax keys in vLLM format:") + for key in sorted(vllm_amax_keys): + print(f" {key}") + print("\nDry run complete. No files saved.") + return + + # Save vLLM checkpoint + print(f"Saving vLLM checkpoint to: {args.output}") + os.makedirs(os.path.dirname(args.output), exist_ok=True) + torch.save(vllm_state_dict, args.output) + print("Conversion complete!") + + +if __name__ == "__main__": + main() diff --git a/examples/vllm_serve/vllm_serve_fakequant.py b/examples/vllm_serve/vllm_serve_fakequant.py new file mode 100644 index 000000000..e96f2d3dc --- /dev/null +++ b/examples/vllm_serve/vllm_serve_fakequant.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# MIT License +# +# Copyright (c) 2023 Deep Cognition and Language Research (DeCLaRe) Lab +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
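+
+# This example serves a ModelOpt fake-quantized ("fakequant") model through the
+# standard vLLM OpenAI-compatible server. The vLLM v1 Worker is monkey-patched
+# below: compile_or_warm_up_model first runs a short calibration pass (with
+# torch.compile temporarily disabled) and calls mtq.quantize() to insert
+# simulated-quantization ops, while determine_available_memory keeps compilation
+# disabled during memory profiling. Pre-computed amax values can optionally be
+# loaded via quant_config["amax_file_path"].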
+ +from contextlib import contextmanager +from typing import Any + +import torch +import uvloop +from tqdm import tqdm +from transformers import AutoTokenizer +from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors +from vllm.utils import FlexibleArgumentParser +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput + +# from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.gpu_worker import Worker + +import modelopt.torch.quantization as mtq +from modelopt.torch.utils.dataset_utils import get_dataset_dataloader + + +@contextmanager +def disable_compilation(model): + """Context manager to temporarily disable torch.compile""" + do_not_compile = True + if hasattr(model, "model"): + do_not_compile = model.model.do_not_compile + model.model.do_not_compile = True + elif hasattr(model, "language_model"): # VLM requires this + do_not_compile = model.language_model.model.do_not_compile + model.language_model.model.do_not_compile = True + else: + raise ValueError("Model does not have a model or language_model attribute") + + try: + yield + finally: + if hasattr(model, "model"): + model.model.do_not_compile = do_not_compile + elif hasattr(model, "language_model"): + model.language_model.model.do_not_compile = do_not_compile + + +quant_config: dict[str, Any] = { + "quant_dataset": "cnn_dailymail", + "quant_num_samples": 512, + "quant_format": "NVFP4_DEFAULT_CFG", + "amax_file_path": None, # Optional: path to pre-computed amax values (e.g., "/path/to/amax.pt") +} + + +def fakequant_run_prolog(self): + tokenizer = AutoTokenizer.from_pretrained( + self.model_config.tokenizer, + trust_remote_code=True, + ) + if tokenizer.pad_token != "" or tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if quant_config["amax_file_path"]: + # If amax file path is provided, we only need to do a simple calibration step + quant_config["quant_num_samples"] = 1 + + calib_dataloader = get_dataset_dataloader( + dataset_name=quant_config["quant_dataset"], + tokenizer=tokenizer, + batch_size=1, + num_samples=quant_config["quant_num_samples"], + device=self.device, + ) + + def calibrate_loop(model: Any = None) -> None: + print("Calibrating model...") + for batch_idx, batch in tqdm(enumerate(calib_dataloader)): + input_ids = batch["input_ids"][0] + + # Convert tensor to list of integers for vLLM compatibility + if torch.is_tensor(input_ids): + input_ids_list = input_ids.cpu().tolist() + else: + input_ids_list = list(input_ids) + + num_groups = len(self.kv_cache_config.kv_cache_groups) + empty_block_ids = tuple([] for _ in range(num_groups)) + + # Build the per-request payload the model runner normally receives from the scheduler. + req_id = f"req-{batch_idx}" + new_req = NewRequestData( + req_id=req_id, + prompt_token_ids=input_ids_list, + mm_kwargs=[], + mm_hashes=[], + mm_positions=[], + sampling_params=SamplingParams(max_tokens=1), + pooling_params=None, + block_ids=empty_block_ids, + num_computed_tokens=0, + lora_request=None, + ) + + # Assemble a SchedulerOutput with all KV-related fields left empty. 
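+            # Each calibration prompt becomes a single new request with
+            # max_tokens=1, so execute_model() below performs one prefill pass
+            # over the prompt tokens; that forward pass is what feeds the
+            # quantizer calibration (amax collection).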
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[new_req],
+                scheduled_cached_reqs=CachedRequestData.make_empty(),
+                num_scheduled_tokens={req_id: len(input_ids_list)},
+                total_num_scheduled_tokens=len(input_ids_list),
+                scheduled_spec_decode_tokens={},
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=[0] * num_groups,
+                finished_req_ids=set(),
+                free_encoder_mm_hashes=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+                kv_connector_metadata=None,
+            )
+            intermediate_tensors = None
+            forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+            if forward_pass and not get_pp_group().is_first_rank:
+                intermediate_tensors = IntermediateTensors(
+                    get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group())
+                )
+            self.execute_model(scheduler_output, intermediate_tensors=intermediate_tensors)
+
+    quant_cfg = getattr(mtq, quant_config["quant_format"])
+
+    with disable_compilation(self.model):
+        mtq.quantize(self.model, quant_cfg, forward_loop=calibrate_loop)
+
+    # Only print on rank 0 to avoid duplicate output in distributed setups
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        print(self.model)
+
+    # Override amax values from a saved state_dict
+    amax_file_path = quant_config["amax_file_path"]
+    if amax_file_path:
+        print(f"Loading amax values from {amax_file_path}")
+        saved_amax_dict = torch.load(amax_file_path, map_location=self.device)
+        current_state_dict = self.model.state_dict()
+
+        # Count amax keys in checkpoint and model
+        checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("amax")]
+        model_amax_keys = [key for key in current_state_dict if key.endswith("amax")]
+
+        checkpoint_amax_count = len(checkpoint_amax_keys)
+        model_amax_count = len(model_amax_keys)
+
+        # Ensure counts match
+        if checkpoint_amax_count != model_amax_count:
+            raise ValueError(
+                f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} "
+                f"amax keys but model has {model_amax_count} amax keys. "
+            )
+
+        # Update amax values
+        for key, value in saved_amax_dict.items():
+            if key in current_state_dict:
+                current_state_dict[key] = value.to(self.device)
+
+        self.model.load_state_dict(current_state_dict, strict=True)
+
+    mtq.fold_weight(self.model)
+
+
+# Store the original Worker methods so the patched versions below can call them
+old_determine_available_memory = Worker.determine_available_memory
+old_compile_or_warm_up_model = Worker.compile_or_warm_up_model
+
+
+# Keep torch.compile disabled while vLLM profiles the available memory
+def new_determine_available_memory(self):
+    with disable_compilation(self.model_runner.model):
+        results = old_determine_available_memory(self)
+    return results
+
+
+def new_compile_or_warm_up_model(self):
+    if quant_config["quant_format"]:
+        fakequant_run_prolog(self.model_runner)
+    old_compile_or_warm_up_model(self)
+
+
+# To make sure these monkey patches propagate to worker subprocesses,
+# do not move them into functions!
+Worker.determine_available_memory = new_determine_available_memory
+Worker.compile_or_warm_up_model = new_compile_or_warm_up_model
+
+
+def main():
+    # Create a parser that handles both quant and serve arguments
+    parser = FlexibleArgumentParser(description="vLLM model server with quantization support")
+    parser.add_argument("model", type=str, help="The path or name of the model to serve")
+    parser = make_arg_parser(parser)
+
+    # Parse arguments
+    args = parser.parse_args()
+    # Run the server
+    uvloop.run(run_server(args))
+
+
+if __name__ == "__main__":
+    main()
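
Once the server started by `vllm_serve_fakequant.py` is up, the chat-completions endpoint can also be exercised from Python instead of curl. This is a minimal client-side sketch, not part of the diff above; it assumes the `openai` client package is installed, that the server is listening on the host/port from the README, and the model name is a placeholder for whatever model was passed on the command line.

```python
# Mirrors the curl example in the README via the OpenAI-compatible API that vLLM exposes.
from openai import OpenAI

# Assumes the server was started with --host 0.0.0.0 --port 8000.
# Any non-empty API key works when the server has no key configured.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="<served-model-name>",  # placeholder: the model passed to vllm_serve_fakequant.py
    messages=[{"role": "user", "content": "Hi, what is your name"}],
    max_tokens=8,
)
print(response.choices[0].message.content)
```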