|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import dataclasses |
| 17 | +import os |
| 18 | +import warnings |
| 19 | +from contextlib import contextmanager |
| 20 | +from typing import Any |
| 21 | + |
| 22 | +import torch |
| 23 | +from tqdm import tqdm |
| 24 | +from transformers import AutoTokenizer |
| 25 | +from vllm.sampling_params import SamplingParams |
| 26 | +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput |
| 27 | +from vllm.v1.worker.gpu_worker import Worker as BaseWorker |
| 28 | + |
| 29 | +import modelopt.torch.quantization as mtq |
| 30 | +from modelopt.torch.utils.dataset_utils import get_dataset_dataloader |
| 31 | + |
| 32 | + |
| 33 | +@contextmanager |
| 34 | +def disable_compilation(model): |
| 35 | + do_not_compile = True |
| 36 | + if hasattr(model, "model"): |
| 37 | + do_not_compile = model.model.do_not_compile |
| 38 | + model.model.do_not_compile = True |
| 39 | + elif hasattr(model, "language_model"): |
| 40 | + do_not_compile = model.language_model.model.do_not_compile |
| 41 | + model.language_model.model.do_not_compile = True |
| 42 | + else: |
| 43 | + raise ValueError("Model does not have a model or language_model attribute") |
| 44 | + |
| 45 | + try: |
| 46 | + yield |
| 47 | + finally: |
| 48 | + if hasattr(model, "model"): |
| 49 | + model.model.do_not_compile = do_not_compile |
| 50 | + elif hasattr(model, "language_model"): |
| 51 | + model.language_model.model.do_not_compile = do_not_compile |
| 52 | + |
| 53 | + |
| 54 | +quant_config: dict[str, Any] = { |
| 55 | + "dataset": os.environ.get("QUANT_DATASET", "cnn_dailymail"), |
| 56 | + "calib_size": int(os.environ.get("QUANT_CALIB_SIZE", 512)), |
| 57 | + "quant_cfg": os.environ.get("QUANT_CFG", "NVFP4_DEFAULT_CFG"), |
| 58 | + "amax_file_path": os.environ.get("AMAX_FILE_PATH", None), |
| 59 | +} |
| 60 | + |
| 61 | + |
| 62 | +def _create_new_data_cls(data_cls, **kwargs): |
| 63 | + """vLLM's low-level API changes frequently. This function creates a class with parameters |
| 64 | + compatible with the different vLLM versions.""" |
| 65 | + valid_params = {field.name for field in dataclasses.fields(data_cls)} |
| 66 | + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} |
| 67 | + return data_cls(**filtered_kwargs) |
| 68 | + |
| 69 | + |
| 70 | +def _fakequant_run_prolog_worker(self) -> None: |
| 71 | + tokenizer = AutoTokenizer.from_pretrained( |
| 72 | + self.model_runner.model_config.tokenizer, |
| 73 | + trust_remote_code=True, |
| 74 | + ) |
| 75 | + if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None: |
| 76 | + tokenizer.pad_token = tokenizer.eos_token |
| 77 | + |
| 78 | + if quant_config["amax_file_path"]: |
| 79 | + print("Will load amax, so only do a single sample calibration") |
| 80 | + quant_config["calib_size"] = 1 |
| 81 | + |
| 82 | + calib_dataloader = get_dataset_dataloader( |
| 83 | + dataset_name=quant_config["dataset"], |
| 84 | + tokenizer=tokenizer, |
| 85 | + batch_size=1, |
| 86 | + num_samples=quant_config["calib_size"], |
| 87 | + device=self.device, |
| 88 | + ) |
| 89 | + |
| 90 | + def calibrate_loop(model: Any = None) -> None: |
| 91 | + for batch_idx, batch in tqdm(enumerate(calib_dataloader)): |
| 92 | + input_ids = batch["input_ids"][0] |
| 93 | + |
| 94 | + # Convert tensor to list of integers for vLLM compatibility |
| 95 | + if torch.is_tensor(input_ids): |
| 96 | + input_ids_list = input_ids.cpu().tolist() |
| 97 | + else: |
| 98 | + input_ids_list = list(input_ids) |
| 99 | + |
| 100 | + num_groups = len(self.model_runner.kv_cache_config.kv_cache_groups) |
| 101 | + empty_block_ids = tuple([] for _ in range(num_groups)) |
| 102 | + |
| 103 | + req_id = f"req-{batch_idx}" |
| 104 | + # Pass all possible parameters - the helper will filter based on vLLM version |
| 105 | + new_req = _create_new_data_cls( |
| 106 | + NewRequestData, |
| 107 | + req_id=req_id, |
| 108 | + prompt_token_ids=input_ids_list, |
| 109 | + # Old API parameters |
| 110 | + mm_kwargs=[], # TODO: remove this when vllm <= 0.11 is outdated |
| 111 | + mm_hashes=[], # TODO: remove this when vllm <= 0.11 is outdated |
| 112 | + mm_positions=[], # TODO: remove this when vllm <= 0.11 is outdated |
| 113 | + # New API parameter |
| 114 | + mm_features=[], |
| 115 | + sampling_params=SamplingParams(max_tokens=1), |
| 116 | + pooling_params=None, |
| 117 | + block_ids=empty_block_ids, |
| 118 | + num_computed_tokens=0, |
| 119 | + lora_request=None, |
| 120 | + ) |
| 121 | + |
| 122 | + scheduler_output = _create_new_data_cls( |
| 123 | + SchedulerOutput, |
| 124 | + scheduled_new_reqs=[new_req], |
| 125 | + scheduled_cached_reqs=CachedRequestData.make_empty(), |
| 126 | + num_scheduled_tokens={req_id: len(input_ids_list)}, |
| 127 | + total_num_scheduled_tokens=len(input_ids_list), |
| 128 | + scheduled_spec_decode_tokens={}, |
| 129 | + scheduled_encoder_inputs={}, |
| 130 | + num_common_prefix_blocks=[0] * num_groups, |
| 131 | + finished_req_ids=set(), |
| 132 | + free_encoder_mm_hashes=[], |
| 133 | + kv_connector_metadata=None, |
| 134 | + # Old API parameters |
| 135 | + structured_output_request_ids={}, # TODO: remove this when vllm <= 0.11 is outdated |
| 136 | + grammar_bitmask=None, # TODO: remove this when vllm <= 0.11 is outdated |
| 137 | + ) |
| 138 | + output = self.execute_model(scheduler_output) |
| 139 | + if hasattr(self, "sample_tokens"): |
| 140 | + if output is None: # TODO: make this default when vllm <= 0.11 is outdated |
| 141 | + self.sample_tokens(None) |
| 142 | + |
| 143 | + quant_cfg = getattr(mtq, quant_config["quant_cfg"]) |
| 144 | + |
| 145 | + model = self.model_runner.model |
| 146 | + if hasattr(model, "unwrap"): |
| 147 | + model = model.unwrap() |
| 148 | + |
| 149 | + with disable_compilation(model): |
| 150 | + print("quantizing model...") |
| 151 | + mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) |
| 152 | + |
| 153 | + amax_file_path = quant_config["amax_file_path"] |
| 154 | + if amax_file_path: |
| 155 | + print(f"Loading amax values from {amax_file_path}") |
| 156 | + saved_amax_dict = torch.load(amax_file_path) |
| 157 | + current_state_dict = model.state_dict() |
| 158 | + |
| 159 | + # Count amax keys in checkpoint and model |
| 160 | + checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] |
| 161 | + model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")] |
| 162 | + for key in checkpoint_amax_keys: |
| 163 | + if key not in model_amax_keys: |
| 164 | + print(f"Key {key} not found in model state dict, but exists in checkpoint") |
| 165 | + for key in model_amax_keys: |
| 166 | + if key not in checkpoint_amax_keys: |
| 167 | + raise ValueError( |
| 168 | + f"Key {key} not found in checkpoint state dict, but exists in model" |
| 169 | + ) |
| 170 | + |
| 171 | + checkpoint_amax_count = len(checkpoint_amax_keys) |
| 172 | + model_amax_count = len(model_amax_keys) |
| 173 | + |
| 174 | + # Ensure counts match |
| 175 | + if checkpoint_amax_count != model_amax_count: |
| 176 | + warnings.warn( |
| 177 | + f"Mismatch in amax key counts: checkpoint has {checkpoint_amax_count} " |
| 178 | + f"amax keys but model has {model_amax_count} amax keys. This can happen if the model is using PP." |
| 179 | + ) |
| 180 | + |
| 181 | + # Update amax values |
| 182 | + for key, value in saved_amax_dict.items(): |
| 183 | + if key in current_state_dict: |
| 184 | + current_state_dict[key] = value.to(current_state_dict[key].device) |
| 185 | + |
| 186 | + model.load_state_dict(current_state_dict) |
| 187 | + torch.distributed.barrier() |
| 188 | + |
| 189 | + if amax_file_path is None: |
| 190 | + # Sync amax across TP can be done here if needed |
| 191 | + pass |
| 192 | + # for name, buffer in model.named_buffers(): |
| 193 | + # if name.endswith("_amax"): |
| 194 | + # print("syncing amax across TP for", name) |
| 195 | + # torch.distributed.all_reduce( |
| 196 | + # buffer, op=torch.distributed.ReduceOp.MAX, group=get_tp_group().device_group |
| 197 | + # ) |
| 198 | + # torch.distributed.barrier() |
| 199 | + |
| 200 | + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: |
| 201 | + mtq.print_quant_summary(model) |
| 202 | + |
| 203 | + mtq.fold_weight(model) |
| 204 | + for name, module in model.named_modules(): |
| 205 | + if name.endswith("weight_quantizer"): |
| 206 | + assert not module.is_enabled, f"quantizer {name} is still enabled" |
| 207 | + |
| 208 | + |
| 209 | +class FakeQuantWorker(BaseWorker): |
| 210 | + @torch.inference_mode() |
| 211 | + def determine_available_memory(self) -> int: |
| 212 | + model = self.model_runner.model |
| 213 | + if hasattr(model, "unwrap"): |
| 214 | + model = model.unwrap() |
| 215 | + with disable_compilation(model): |
| 216 | + return super().determine_available_memory() |
| 217 | + |
| 218 | + def compile_or_warm_up_model(self) -> None: |
| 219 | + if quant_config["quant_cfg"]: |
| 220 | + _fakequant_run_prolog_worker(self) |
| 221 | + super().compile_or_warm_up_model() |
0 commit comments