diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index f300bd378..1b5581604 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1111,6 +1111,39 @@ class SVDQuantConfig(QuantizeAlgorithmConfig): ) +class GPTQLiteConfig(QuantizeAlgorithmConfig): + """The config for GPTQ lite. + + GPTQ lite is a variant of GPTQ that does not exactly follow the official GPTQ implementation. + + GPTQ lite does not perform sequential quantization of layers. This means that the updated + activations are not used to process the next layer. + + GPTQ lite also uses dynamic scales computed during the weight update phase. The original GPTQ + implementation uses static scales computed on the weights before beginning blockwise update. + + """ + + method: Literal["gptq_lite"] = ModeloptField("gptq_lite") + percdamp: float | None = ModeloptField( + default=0.01, + gt=0.0, + le=1.0, + title="Percentage damping factor.", + description="The percentage of average Hessian diagonal used for damping.", + ) + block_size: int | None = ModeloptField( + default=128, + title="Block size for GPTQ weight update.", + description="The block size for GPTQ weight update.", + ) + hessian_state_path: str | None = ModeloptField( + default=None, + title="Path to the Hessian state file.", + description="The path to the Hessian state file.", + ) + + QuantizeQuantCfgType = dict[ str | Callable, QuantizerAttributeConfig diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 4e6e9fd49..e4adde2a6 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -37,6 +37,7 @@ AWQFullCalibConfig, AWQLiteCalibConfig, CompressConfig, + GPTQLiteConfig, MaxCalibConfig, QuantizeAlgoCfgType, QuantizeAlgorithmConfig, @@ -54,7 +55,7 @@ restore_svdquant_model, update_quantize_metadata, ) -from .model_calib import awq, max_calibrate, smoothquant, svdquant +from .model_calib import awq, gptq_lite, max_calibrate, smoothquant, svdquant __all__ = ["BaseCalibrateModeDescriptor"] @@ -426,3 +427,15 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]: def restore(self) -> RestoreEntrypoint: """The mode's entrypoint for restoring a model.""" return restore_svdquant_model + + +@CalibrateModeRegistry.register_mode +class GPTQLiteModeDescriptor(BaseCalibrateModeDescriptor): + """Mode for GPTQ calibration algorithm.""" + + @property + def config_class(self) -> type[QuantizeAlgorithmConfig]: + """Specifies the config class for the mode.""" + return GPTQLiteConfig + + _calib_func = gptq_lite diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c1e6feb06..3c7bd9136 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,7 +15,9 @@ """Calibration utilities.""" +import gc import math +import os import warnings from functools import partial @@ -23,11 +25,13 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +from tqdm import tqdm from modelopt.torch.opt.searcher import ForwardLoop from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method +from modelopt.torch.utils.perf import get_gpu_mem_fraction from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context from .nn import 
QuantModule, SequentialQuantizer, TensorQuantizer
@@ -953,3 +957,322 @@ def postprocess(module, name):
         with enable_weight_access_and_writeback(module, model):
             postprocess(module, name)
     max_calibrate(model, forward_loop)
+
+
+def update_hessian(input, hessian, n_samples):
+    """Update the Hessian matrix with new input samples using an incremental averaging formula.
+
+    Args:
+        input: Input tensor (batch_size, ..., features)
+        hessian: Current Hessian matrix to update in-place
+        n_samples: Number of samples (sequences) already processed
+
+    Returns:
+        Tuple of (updated_hessian, new_sample_count)
+    """
+    batch_size = input.shape[0]
+
+    # Incremental averaging: scale down the old hessian
+    hessian *= n_samples / (n_samples + batch_size)
+    n_samples += batch_size
+
+    # Accumulate the scaled outer product: H += 2 * X^T X / n_samples
+    input_flat = input.reshape(-1, input.shape[-1]).t().float()
+    scaled_input = math.sqrt(2 / n_samples) * input_flat
+    hessian.add_((scaled_input @ scaled_input.t()).to(hessian.device))
+
+    return hessian, n_samples
+
+
+def prepare_hessian_inverse(h, weight, percdamp):
+    """Prepare the inverse Hessian (upper Cholesky factor) with dead neuron handling and damping."""
+    h = h.clone()
+
+    # Handle dead neurons (zero diagonal elements)
+    dead_mask = torch.diag(h) == 0
+    h[dead_mask, dead_mask] = 1
+    weight[:, dead_mask] = 0
+
+    # Add damping to the diagonal
+    damp = percdamp * torch.mean(torch.diag(h))
+    diag_indices = torch.arange(h.shape[0], device=h.device)
+    h[diag_indices, diag_indices] += damp
+
+    try:
+        h = torch.cholesky_inverse(torch.linalg.cholesky(h))
+        h_inv = torch.linalg.cholesky(h, upper=True)
+    except (RuntimeError, torch.linalg.LinAlgError):
+        print("Warning: Hessian is not positive definite, using identity matrix")
+        h_inv = torch.eye(h.shape[0], device=h.device, dtype=h.dtype)
+    return h_inv
+
+
+def quantize_block(full_weight, block_start, block_end, h_inv, quantizer):
+    """Quantize a block of weights group by group (based on quantizer block sizes) with error propagation.
+
+    Args:
+        full_weight: The complete weight tensor (the quantizer is applied to the full tensor so
+            that group-wise scales are computed correctly)
+        block_start: Starting column index of the block
+        block_end: Ending column index of the block
+        h_inv: Upper Cholesky factor of the inverse Hessian (as returned by prepare_hessian_inverse)
+        quantizer: The quantizer to apply
+
+    Returns:
+        quantized_block: Quantized weights for this block
+        losses: Quantization losses per element
+        errors: Accumulated errors for propagation
+    """
+    # Extract the block we're working on
+    block_weight = full_weight[:, block_start:block_end].clone()
+    block_hinv = h_inv[block_start:block_end, block_start:block_end]
+    block_size = block_end - block_start
+
+    quantized_block = torch.zeros_like(block_weight)
+    losses = torch.zeros_like(block_weight)
+    errors = torch.zeros_like(block_weight)
+
+    group_size = None
+    if getattr(quantizer, "block_sizes", None) is not None:
+        group_size = quantizer.block_sizes[-1]
+
+    if group_size is None:
+        warnings.warn("Block sizes not found in quantizer, using group size of 1")
+        group_size = 1
+
+    assert block_size % group_size == 0, "Block size must be divisible by group size"
+
+    for group_start in range(0, block_size, group_size):
+        group_end = min(group_start + group_size, block_size)
+        group_cols = slice(group_start, group_end)
+        # Get the current group of columns and the corresponding Hessian inverse diagonal
+        weight_col = block_weight[:, group_cols]
+        hinv_diag = torch.diag(block_hinv[group_cols, group_cols])
+
+        # Quantize using the full weight, then extract the columns we need
+        quantized_full = quantizer(full_weight)
+        quantized_cols = quantized_full[:, block_start + group_start : block_start + group_end]
+        quantized_block[:, group_cols] = quantized_cols
+
+        # Compute quantization error and loss
+        error = (weight_col - quantized_cols) / hinv_diag
+        losses[:, group_cols] = (weight_col - quantized_cols) ** 2 / (hinv_diag**2) / 2
+        errors[:, group_cols] = error
+
+        # Propagate error to the remaining columns in the block
+        block_weight[:, group_start:] -= error @ block_hinv[group_start:group_end, group_start:]
+
+    return quantized_block, losses, errors
+
+
+def print_relative_mse_error(q: torch.Tensor, w: torch.Tensor, h: torch.Tensor, module_name: str):
+    """Print the relative mean squared error between quantized and original weights.
+
+    Computes the Hessian-weighted relative MSE between quantized and original weights,
+    providing a measure of quantization quality.
+
+    Args:
+        q (torch.Tensor): Quantized weight tensor
+        w (torch.Tensor): Original weight tensor
+        h (torch.Tensor): Hessian matrix used for weighting the error
+        module_name (str): Name of the module for logging purposes
+
+    Note:
+        Metric adapted from the GPTQ-style implementation in
+        https://github.com/IST-DASLab/FP-Quant
+    """
+    delta = q - w
+    mse = delta.mm(h).mul(delta).mean() / (w.mm(h).mul(w).mean() + 1e-6)
+    print(f"[{module_name}] Relative MSE error: {mse.item():.2e}")
+
+
+def blockwise_weight_update(module, h, block_size, percdamp):
+    """Update module weights using GPTQ-style blockwise quantization.
+
+    Args:
+        module: Neural network module with weight and weight_quantizer
+        h: Hessian matrix (in_features x in_features)
+        block_size: Number of weight columns to process per block
+        percdamp: Damping percentage for the Hessian diagonal
+    """
+    weight = module.weight.data.clone()
+    _, num_cols = weight.shape
+
+    # Preprocess the Hessian (dead neurons, damping) and get the inverse Cholesky factor
+    h_inv = prepare_hessian_inverse(h, weight, percdamp).to(weight.dtype)
+
+    # Initialize output tensors
+    quantized_weight = torch.zeros_like(weight)
+    losses = torch.zeros_like(weight)
+
+    # Process weights in blocks
+    for block_start in range(0, num_cols, block_size):
+        block_end = min(block_start + block_size, num_cols)
+
+        quantized_block, block_losses, block_errors = quantize_block(
+            weight, block_start, block_end, h_inv, module.weight_quantizer
+        )
+        # Store results
+        quantized_weight[:, block_start:block_end] = quantized_block
+        losses[:, block_start:block_end] = block_losses
+
+        # Propagate errors to the remaining weights
+        weight[:, block_end:] -= block_errors @ h_inv[block_start:block_end, block_end:]
+
+    # Report the Hessian-weighted relative MSE
+    print_relative_mse_error(
+        quantized_weight, module.weight, h.to(module.weight.dtype), module.name
+    )
+    # Update module weights
+    module.weight.data = quantized_weight.reshape(module.weight.shape).to(module.weight.data.dtype)
+
+
+def gptq_lite(
+    model: nn.Module,
+    forward_loop: ForwardLoop | None = None,
+    percdamp: float = 0.01,
+    block_size: int = 128,
+    hessian_state_path: str | None = None,
+):
+    """GPTQ-lite quantization - a simplified GPTQ variant.
+
+    Key differences from GPTQ:
+    - Layers are quantized in parallel (not sequentially with updated activations)
+    - Uses group-wise updates instead of column-wise updates
+
+    Args:
+        model: Model to be calibrated.
+        forward_loop: Callable that forwards calibration data through the model.
+        percdamp: Percentage of the average Hessian diagonal used for damping (default: 0.01).
+        block_size: Block size for the GPTQ weight update.
+        hessian_state_path: Path to save/load the Hessian state. If None, compute without saving.
+            If the path exists, load from it; otherwise compute and save to it.
+
+    See :class:`GPTQLiteConfig` for the corresponding configuration options.
+ """ + # Dictionary to store hessian matrices: {layer_name: {"hessian": Tensor, "n_samples": int}} + hessian_state = {} + + def initialize_hessian_state(tensor_mapping): + """Initialize hessian state with zeros.""" + for name, (shape, device) in tensor_mapping.items(): + # Use CPU if GPU memory is tight + target_device = "cpu" if get_gpu_mem_fraction(device) > 0.65 else device + hessian_state[name] = { + "hessian": torch.zeros(shape, dtype=torch.float32, device=target_device), + "n_samples": 0, + } + + def load_hessian_state(path, tensor_mapping): + """Load hessian state from file.""" + print_rank_0(f"Loading hessian state from {path}") + loaded_state = torch.load(path, map_location="cpu") + + for name, (shape, device) in tensor_mapping.items(): + if name not in loaded_state: + raise KeyError(f"Layer '{name}' not found in loaded hessian state") + + # Move to appropriate device based on memory + target_device = "cpu" if get_gpu_mem_fraction(device) > 0.65 else device + hessian_state[name] = { + "hessian": loaded_state[name]["hessian"].to(target_device), + "n_samples": loaded_state[name]["n_samples"], + } + + print_rank_0(f"Successfully loaded hessian state with {len(hessian_state)} layers") + + def save_hessian_state(path): + """Save hessian state to file.""" + print_rank_0(f"Saving hessian state to {path}") + try: + # Move to CPU for saving + cpu_state = { + name: {"hessian": state["hessian"].cpu(), "n_samples": state["n_samples"]} + for name, state in hessian_state.items() + } + + os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True) + torch.save(cpu_state, path) + print_rank_0(f"Successfully saved hessian state to {path}") + except Exception as e: + print_rank_0(f"Error saving hessian state: {e}") + print_rank_0("Continuing execution...") + + def hessian_hook(module, input, output): + """Hook to intercept activations and update hessian matrix.""" + state = hessian_state[module.name] + hessian, n_samples = update_hessian(input[0], state["hessian"], state["n_samples"]) + hessian_state[module.name] = {"hessian": hessian, "n_samples": n_samples} + torch.cuda.empty_cache() + gc.collect() + + # Phase 1: Collect statistics for quantizers + enable_stats_collection(model) + max_calibrate(model, forward_loop) + finish_stats_collection(model) + + # Phase 2: Build tensor mapping for all quantized layers + tensor_mapping = {} + for name, module in model.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + in_features = module.weight.shape[1] + tensor_mapping[name] = ((in_features, in_features), module.weight.device) + module.name = name # Attach name for easy access in hooks + + print_rank_0(f"Found {len(tensor_mapping)} quantized layers") + + # Phase 3: Load or compute Hessians + hessian_exists = hessian_state_path is not None and os.path.exists(hessian_state_path) + + if hessian_exists: + load_hessian_state(hessian_state_path, tensor_mapping) + else: + if forward_loop is None: + raise ValueError("forward_loop must be provided when computing Hessians") + + # Initialize hessian state + initialize_hessian_state(tensor_mapping) + + # Register hooks to collect activations + handles = [] + for name, module in model.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + handles.append(module.register_forward_hook(hessian_hook)) + + # Run forward loop to compute hessians + print_rank_0("Computing Hessian matrices...") + forward_loop(model) + + # Remove hooks + for handle in handles: + handle.remove() 
+ + # Save if path provided + if hessian_state_path is not None: + save_hessian_state(hessian_state_path) + + # Phase 4: Update weights using computed Hessians + print_rank_0("Updating weights using GPTQ-lite algorithm...") + + quantized_modules = [ + (name, module) + for name, module in model.named_modules() + if is_quantized_linear(module) and module.weight_quantizer.is_enabled + ] + + for name, module in tqdm(quantized_modules, desc="Quantizing layers"): + state = hessian_state[module.name] + hessian = state["hessian"].to(module.weight.device) + blockwise_weight_update(module, hessian, block_size, percdamp) + torch.cuda.empty_cache() + + # Phase 5: Reset and recalibrate quantizer statistics + for name, module in model.named_modules(): + if is_quantized_linear(module) and module.weight_quantizer.is_enabled: + module.input_quantizer.reset_amax() + module.output_quantizer.reset_amax() + + enable_stats_collection(model) + max_calibrate(model, forward_loop) + finish_stats_collection(model) + + print_rank_0("GPTQ-lite quantization completed successfully") diff --git a/modelopt/torch/utils/perf.py b/modelopt/torch/utils/perf.py index cd2652f94..0f460566f 100644 --- a/modelopt/torch/utils/perf.py +++ b/modelopt/torch/utils/perf.py @@ -28,6 +28,7 @@ "Timer", "clear_cuda_cache", "get_cuda_memory_stats", + "get_gpu_mem_fraction", "report_memory", ] @@ -48,6 +49,23 @@ def get_cuda_memory_stats(device=None): } +def get_gpu_mem_fraction(device="cuda:0"): + """Get used GPU memory as a fraction of total memory. + + Args: + device: Device identifier (default: "cuda:0") + + Returns: + float: Fraction of GPU memory currently used (0.0 to 1.0). + Returns 0.0 if CUDA is not available. + """ + if not torch.cuda.is_available(): + return 0.0 + + free_memory, total_memory = torch.cuda.mem_get_info(device) + return (total_memory - free_memory) / total_memory + + def report_memory(name="", rank=0, device=None): """Simple GPU memory report.""" memory_stats = get_cuda_memory_stats(device) diff --git a/tests/gpu/torch/quantization/test_gptq.py b/tests/gpu/torch/quantization/test_gptq.py new file mode 100644 index 000000000..c3e828a58 --- /dev/null +++ b/tests/gpu/torch/quantization/test_gptq.py @@ -0,0 +1,171 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import modelopt.torch.quantization as mtq
+from modelopt.torch.quantization.model_calib import blockwise_weight_update, update_hessian
+from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader
+
+RAND_SEED = 42
+torch.manual_seed(RAND_SEED)
+
+
+@pytest.mark.parametrize(
+    ("block_size", "dim", "model_weight", "expect_weight_change"),
+    [
+        (4, 16, torch.randn(16, 16).to("cuda"), True),  # random weight
+        (
+            4,
+            16,
+            torch.ones(16, 16).to("cuda"),
+            False,
+        ),  # all same weight -> no quantization error -> no GPTQ update
+        (
+            4,
+            32,
+            torch.tensor(
+                [0, 0.5, 1, -0.5, 0, 0.5, 1, -0.5, 0, 0.5, 1, -0.5, 0, 0.5, 1, -0.5,
+                 -4, -2, 0, 6, -6, -4, -2, 0, -4, -2, 0, 6, -6, -4, -2, 0]
+            )
+            .to("cuda")
+            .expand(32, -1),
+            False,
+        ),  # weights with nvfp4 values -> no GPTQ update
+    ],
+)
+def test_gptq_updates(block_size, dim, model_weight, expect_weight_change):
+    model = torch.nn.Linear(dim, 1).to("cuda")
+    model.weight.data = model_weight
+    original_weight = model_weight.clone()
+    input = torch.randn(2, 16, dim).to("cuda")
+    hessian = torch.zeros(dim, dim).to("cpu")
+    n_samples = 0
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+
+    mtq.quantize(model, quant_cfg, forward_loop=lambda model: model(input))
+
+    # Get qdq weight
+    q_dq_weight = model.weight_quantizer(model.weight.data)
+
+    # Restore original weight
+    model.weight.data = original_weight.clone()
+
+    hessian, n_samples = update_hessian(input, hessian, n_samples)
+
+    # Verify n_samples is updated after the hessian update
+    assert n_samples == input.shape[0], "n_samples should be equal to input.shape[0]"
+
+    # Perform another forward pass to update the hessian matrix
+    input_2 = torch.randn(3, 16, dim).to("cuda")
+    hessian, n_samples = update_hessian(input_2, hessian, n_samples)
+    assert n_samples == input.shape[0] + input_2.shape[0], (
+        "n_samples should be equal to input.shape[0] + input_2.shape[0]"
+    )
+
+    hessian = hessian.to(input.device)
+    model.name = "test_linear"  # blockwise_weight_update reports per-module error via module.name
+    blockwise_weight_update(model, hessian, block_size, 0.1)
+    if expect_weight_change:
+        # Weight must change as GPTQ updates weights to adjust for quantization error
+        assert not torch.allclose(model.weight.data, q_dq_weight), "Weight should not be equal"
+    else:
+        assert torch.allclose(model.weight.data, q_dq_weight), "Weight should be equal"
+
+
+@pytest.mark.parametrize(
+    "quant_cfg", [mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG]
+)
+def test_gptq_e2e_flow(quant_cfg):
+    model = AutoModelForCausalLM.from_pretrained(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True
+    )
+
+    # can't set attribute 'pad_token' for "<unk>"
+    # We skip this step for Nemo models
+    if tokenizer.pad_token != "<unk>" or tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Left padding usually provides better calibration results.
+    tokenizer.padding_side = "left"
+
+    assert tokenizer.pad_token is not None, "Pad token cannot be set!"
+    model.eval()
+
+    quant_cfg["algorithm"] = "gptq_lite"
+    # Define the calibration dataloader
+    calib_dataloader = get_dataset_dataloader(
+        dataset_name="cnn_dailymail",
+        tokenizer=tokenizer,
+        batch_size=32,
+        num_samples=512,
+        device="cuda",
+        include_labels=False,
+    )
+    # Only run a single prompt for preview
+    prompt = "Where is New York city?"
+    input_ids = tokenizer(prompt, return_tensors="pt")
+    print(f"Input ids: {input_ids}")
+    generated_ids_before_ptq = model.generate(
+        input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0
+    )
+
+    print(
+        f"Generated text before quantization: {tokenizer.decode(generated_ids_before_ptq[0], skip_special_tokens=True)}"
+    )
+    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    generated_ids_after_ptq = model.generate(
+        input_ids["input_ids"].to("cuda"), max_new_tokens=100, do_sample=False, temperature=0.0
+    )
+    print(
+        f"Generated text after quantization: {tokenizer.decode(generated_ids_after_ptq[0], skip_special_tokens=True)}"
+    )
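A quick sanity sketch (not part of the patch) of what update_hessian accumulates: after all batches, the incrementally averaged Hessian should equal 2/N * X^T X computed over all tokens at once, with N counting sequences. This assumes the function is importable exactly as added to model_calib.py above; the shapes and seed below are arbitrary.

import torch

from modelopt.torch.quantization.model_calib import update_hessian

torch.manual_seed(0)
features = 8
batches = [torch.randn(2, 4, features), torch.randn(3, 4, features)]

# Incremental accumulation, as done inside the gptq_lite forward hook
hessian = torch.zeros(features, features)
n_samples = 0
for x in batches:
    hessian, n_samples = update_hessian(x, hessian, n_samples)

# Reference: H = 2 / n_samples * X^T X over all tokens, where n_samples counts sequences
x_all = torch.cat([x.reshape(-1, features) for x in batches], dim=0).float()
h_ref = 2.0 / n_samples * (x_all.t() @ x_all)

assert n_samples == 5
assert torch.allclose(hessian, h_ref, atol=1e-5)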
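And a minimal usage sketch for the new mode, mirroring test_gptq_e2e_flow: the toy model, calibration data, and calibrate_loop below are illustrative only. Only the string form of the algorithm selector is exercised by the test; the commented-out dict form is an assumption based on the GPTQLiteConfig fields. A CUDA device is assumed, since the calibration path calls torch.cuda.empty_cache().

import torch

import modelopt.torch.quantization as mtq

# Toy model and calibration data (illustrative only)
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64)
).cuda()
calib_data = [torch.randn(4, 64, device="cuda") for _ in range(8)]


def calibrate_loop(m):
    for batch in calib_data:
        m(batch)


quant_cfg = dict(mtq.NVFP4_DEFAULT_CFG)  # shallow copy so the shared default is not mutated
quant_cfg["algorithm"] = "gptq_lite"  # string form, as in the e2e test
# Assumed dict form exposing the GPTQLiteConfig fields:
# quant_cfg["algorithm"] = {"method": "gptq_lite", "percdamp": 0.01, "block_size": 128}

model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)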