diff --git a/tools/llm/quantize_utils.py b/tools/llm/quantize_utils.py
new file mode 100644
index 0000000000..28228be074
--- /dev/null
+++ b/tools/llm/quantize_utils.py
@@ -0,0 +1,266 @@
+import json
+import logging
+import os
+
+import huggingface_hub
+import torch
+from huggingface_hub import snapshot_download
+
+logger = logging.getLogger(__name__)
+
+try:
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.quantize_op.default
+except Exception as e:
+    logger.warning("Unable to import quantization op. Please install modelopt library")
+from modelopt.core.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor
+from modelopt.torch.quantization.config import QuantizerAttributeConfig
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
+from modelopt.torch.utils.dataset_utils import (
+    create_forward_loop,
+    get_dataset_dataloader,
+)
+from safetensors import safe_open
+
+
+def quantize_model(model, args, tokenizer):
+    """
+    Quantize a PyTorch model using ModelOpt quantization.
+
+    This function performs post-training quantization (PTQ) on the model using
+    calibration data from the provided tokenizer. It supports both FP8 and NVFP4
+    quantization formats.
+
+    Args:
+        model: PyTorch model to quantize
+        args: Arguments containing quantization format and debug settings
+        tokenizer: Tokenizer for creating calibration dataloader
+
+    Returns:
+        Quantized model with reduced precision weights and activations
+
+    Raises:
+        RuntimeError: If unsupported quantization format is specified
+    """
+    # Create calibration dataloader for quantization
+    calib_dataloader = get_dataset_dataloader(
+        tokenizer=tokenizer,
+        batch_size=32,
+        num_samples=512,
+        device="cuda:0",
+    )
+    if args.qformat == "fp8":
+        quant_cfg = mtq.FP8_DEFAULT_CFG
+    elif args.qformat == "nvfp4":
+        quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    else:
+        raise RuntimeError("Unsupported quantization format")
+    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    if args.debug:
+        mtq.print_quant_summary(model)
+
+    return model
+
+
+class TensorRTQuantizedLinear(torch.nn.Module):
+    """
+    TensorRT quantized linear layer that applies quantization to both input and weight tensors.
+    """
+
+    def __init__(
+        self, original_linear: torch.nn.Linear, input_amax, weight_amax, quant_cfg
+    ):
+        """
+        Initialize quantized linear layer.
+
+        Args:
+            original_linear: Original PyTorch linear layer to quantize
+            input_amax: Maximum absolute value for input quantization scaling
+            weight_amax: Maximum absolute value for weight quantization scaling
+            quant_cfg: Quantization configuration for TensorQuantizer
+        """
+        super().__init__()
+
+        # Store reference to original linear layer for weight access
+        self.original_linear = original_linear
+
+        # Copy bias from original layer if it exists
+        if original_linear.bias is not None:
+            self.bias = torch.nn.Parameter(original_linear.bias.clone()).cuda()
+        else:
+            self.bias = None
+
+        # Create quantizers for input and weight tensors
+        self.input_quantizer = TensorQuantizer(
+            quant_attribute_cfg=quant_cfg, amax=input_amax
+        )
+        self.weight_quantizer = TensorQuantizer(
+            quant_attribute_cfg=quant_cfg, amax=weight_amax
+        )
+
+    def forward(self, input):
+        input = self.input_quantizer(input)
+        weight = self.weight_quantizer(self.original_linear.weight)
+        return torch.nn.functional.linear(input, weight, self.bias)
+
+
+def convert_linear_to_tensorrt_quantized(model, model_name):
+    """
+    Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.
+
+    This function is specifically designed for Hugging Face quantized models and only
+    applies quantization to linear operations. It loads pre-quantized models from
+    Hugging Face format and replaces standard linear layers with TensorRTQuantizedLinear
+    layers. It supports both FP8 and NVFP4 quantization formats.
+
+    The function:
+    1. Loads quantization scales from Hugging Face model files (SafeTensors)
+    2. Parses quantization configuration from hf_quant_config.json
+    3. Replaces standard linear layers with TensorRTQuantizedLinear layers
+    4. Applies appropriate quantization based on the model's quantization format
+
+    Note: This function only quantizes linear operations and is intended for use
+    with pre-quantized Hugging Face models that have been quantized using ModelOpt.
+
+    Args:
+        model: PyTorch model to quantize
+        model_name: Path to Hugging Face model directory or model identifier
+
+    Returns:
+        Model with quantized linear layers
+
+    Raises:
+        RuntimeError: If quantization config is not found or unsupported format
+    """
+    # Determine if model_name is a local directory or needs to be downloaded
+    if os.path.isdir(model_name):
+        hf_folder = model_name
+    else:
+        # Download model from Hugging Face Hub
+        hf_folder = snapshot_download(
+            model_name,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_patterns=["original/**/*"],
+            revision=None,
+        )
+
+    # Load all tensors from SafeTensors files
+    tensors = {}
+    for file in os.listdir(hf_folder):
+        if file.endswith(".safetensors"):
+            with safe_open(
+                os.path.join(hf_folder, file), framework="pt", device="cpu"
+            ) as f:
+                tensor_names = f.keys()
+                for name in tensor_names:
+                    tensors[name] = f.get_tensor(name)
+
+    # Load and parse quantization configuration
+    hf_quant_config_path = f"{hf_folder}/hf_quant_config.json"
+    if os.path.exists(hf_quant_config_path):
+        with open(hf_quant_config_path, "r") as f:
+            hf_quant_config = json.load(f)
+        hf_quant_config = hf_quant_config["quantization"]
+
+        hf_quant_algo = hf_quant_config.pop("quant_algo", None)
+        if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
+            raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
+    else:
+        raise RuntimeError("No quantization config found")
+
+    # Iterate through all modules in the model
+    for name, module in model.named_modules():
+        # Check if the module is a linear layer
+        target = torch.nn.modules.linear.Linear
+        if isinstance(module, target):
+            # Construct names for quantization scale tensors
+            # These follow the naming convention: module_name.weight_scale and module_name.input_scale
+            weight_scale_name = name + ".weight_scale"
+            input_scale_name = name + ".input_scale"
+
+            if weight_scale_name not in tensors:
+                print(f"Weight scale tensor {weight_scale_name} not found")
+                continue
+            if input_scale_name not in tensors:
+                print(f"Input scale tensor {input_scale_name} not found")
+                continue
+
+            if hf_quant_algo == "FP8":
+                # FP8 E4M3 format has a maximum representable value of 448.0
+                # Scale the quantization parameters accordingly
+                weight_scale = tensors.pop(weight_scale_name)
+                weight_amax = weight_scale * 448.0
+                input_amax = tensors.pop(input_scale_name) * 448.0
+
+                # Dequantize the weight using the scale factor
+                dequantized_weight_data = module.weight.to(torch.float32) * weight_scale
+
+                # Configure quantizer for FP8 format (4 exponent bits, 3 mantissa bits)
+                quantizer_attribute_config = QuantizerAttributeConfig(
+                    num_bits=(4, 3), axis=None
+                )
+
+            elif hf_quant_algo == "NVFP4":
+                # NVFP4 format requires additional scale tensor and different configuration
+                weight_name = name + ".weight"
+                weight_scale2_name = name + ".weight_scale_2"
+                weight_scale = tensors.pop(weight_scale_name)
+                input_scale = tensors.pop(input_scale_name)
+                weight_scale2 = tensors.pop(weight_scale2_name)
+
+                # Calculate amax values with additional scaling factor for NVFP4
+                input_amax = input_scale * 448.0 * 6.0
+                weight_amax = weight_scale2 * 448.0 * 6.0
+
+                # Handle NVFP4 tensor format
+                weight_data = tensors.pop(weight_name)
+                original_shape = list(weight_data.shape)
+                original_shape[-1] *= 2  # NVFP4 packs 2 values per element
+                nvfp4_tensor = NVFP4QTensor(
+                    torch.Size(original_shape), torch.float32, weight_data
+                )
+
+                # Dequantize using both scales and block size configuration
+                dequantized_weight_data = nvfp4_tensor.dequantize(
+                    scale=weight_scale, double_scale=weight_scale2, block_sizes={-1: 16}
+                )
+
+                # Configure quantizer for NVFP4 format with dynamic block quantization
+                quantizer_attribute_config = QuantizerAttributeConfig(
+                    num_bits=(2, 1),
+                    axis=None,
+                    block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+                    enable=True,
+                )
+
+            # Restore the weight to its original full-precision format so that QDQ nodes
+            # can be properly inserted and optimized during TensorRT compilation
+            module.weight.data = dequantized_weight_data
+
+            # Create the quantized linear layer with calculated amax values
+            quantized_module = TensorRTQuantizedLinear(
+                module, input_amax, weight_amax, quantizer_attribute_config
+            )
+
+            # Replace the original module with the quantized version
+            # Extract parent module name and child module name
+            parent_name = ".".join(name.split(".")[:-1])
+            child_name = name.split(".")[-1]
+
+            if parent_name:
+                # Get the parent module and replace the child
+                parent_module = model.get_submodule(parent_name)
+                setattr(parent_module, child_name, quantized_module)
+            else:
+                # If no parent, replace at model level
+                setattr(model, child_name, quantized_module)
+
+    # Log any unused tensors for debugging
+    if len(tensors) > 0:
+        logger.debug(f"{len(tensors)} tensors not used")
+        for key in tensors:
+            logger.debug(f" {key}")
+    return model
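Reviewer note, not part of the patch: a minimal usage sketch of the two entry points defined above. The checkpoint id and the SimpleNamespace are illustrative assumptions standing in for real CLI arguments; run_llm.py (next file) wires the same calls behind the new --qformat and --pre_quantized flags.

# Illustrative sketch only (assumes a CUDA GPU with transformers and nvidia-modelopt installed).
from types import SimpleNamespace

from transformers import AutoModelForCausalLM, AutoTokenizer

from quantize_utils import convert_linear_to_tensorrt_quantized, quantize_model

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # mirrors the pad-token handling added in run_llm.py

# Path 1: calibration-based PTQ via ModelOpt (what --qformat fp8|nvfp4 triggers);
# quantize_model only reads args.qformat and args.debug.
args = SimpleNamespace(qformat="fp8", debug=False)
model = quantize_model(model, args, tokenizer)

# Path 2: reuse scales shipped with an already-quantized HF checkpoint
# (what --pre_quantized triggers); the checkpoint must contain .safetensors
# scale tensors plus hf_quant_config.json.
# model = convert_linear_to_tensorrt_quantized(model, model_id).cuda()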
diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
index 7e50b515c2..56291d0ecf 100644
--- a/tools/llm/run_llm.py
+++ b/tools/llm/run_llm.py
@@ -9,6 +9,7 @@
 
 import argparse
 import copy
+import json
 import os
 import timeit
 from contextlib import nullcontext
@@ -54,10 +55,13 @@ def get_model(args):
             args.model,
             use_cache=False,
             attn_implementation="sdpa",
+            ignore_mismatched_sizes=True,
         )
         .eval()
         .cuda()
     )
+    if args.pre_quantized:
+        model = convert_linear_to_tensorrt_quantized(model, args.model).cuda()
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
@@ -91,7 +95,8 @@ def compile_torchtrt(model, input_ids, args):
         for optimized inference
     """
     max_seq_len = input_ids.shape[1] + args.num_tokens
-    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
+    with export_torch_mode() if args.qformat or args.pre_quantized else nullcontext():
+        ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
     position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
     # Set precision specific flags
     use_fp32_acc = False
@@ -234,13 +239,32 @@ def measure_perf(trt_model, input_signature, backend_name):
     arg_parser.add_argument(
         "--benchmark", action="store_true", help="Enable benchmark (default: False)"
     )
-
+    arg_parser.add_argument(
+        "--qformat",
+        help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"),
+        default=None,
+    )
+    arg_parser.add_argument(
+        "--pre_quantized",
+        action="store_true",
+        help="Use pre-quantized hf model weights (default: False)",
+    )
     args = arg_parser.parse_args()
+
+    if args.qformat or args.pre_quantized:
+        from modelopt.torch.quantization.utils import export_torch_mode
+        from quantize_utils import (
+            convert_linear_to_tensorrt_quantized,
+            quantize_model,
+        )
+
     with torch.inference_mode():
         model = get_model(args)
 
         tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
-
+        # Set pad token
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
         # Prepare input for benchmarking or evaluation
         if args.benchmark:
             input_ids = torch.randint(
@@ -258,6 +282,8 @@ def measure_perf(trt_model, input_signature, backend_name):
         pyt_timings = None
         pyt_stats = None
 
+        if args.qformat != None:
+            model = quantize_model(model, args, tokenizer)
         if args.enable_pytorch_run:
             pyt_gen_tokens = generate(
                 model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
diff --git a/tools/llm/utils.py b/tools/llm/utils.py
index 2c3434b0ed..90561ef9f7 100644
--- a/tools/llm/utils.py
+++ b/tools/llm/utils.py
@@ -1,4 +1,3 @@
-import copy
 import timeit
 
 import numpy as np