
Commit 0b0c3fc

chore: add nvfp4 quantization
1 parent 208df56 commit 0b0c3fc

3 files changed: +312 -227 lines changed


tools/llm/quantize_utils.py

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
import json
import logging
import os

import huggingface_hub
import torch
from huggingface_hub import snapshot_download

logger = logging.getLogger(__name__)

try:
    import modelopt.torch.quantization as mtq  # noqa: F401

    assert torch.ops.tensorrt.quantize_op.default
except Exception as e:
    logger.warning("Unable to import quantization op. Please install modelopt library")
from modelopt.core.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor
from modelopt.torch.quantization.config import QuantizerAttributeConfig
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer
from modelopt.torch.utils.dataset_utils import (
    create_forward_loop,
    get_dataset_dataloader,
)
from safetensors import safe_open


def quantize_model(model, args, tokenizer):
    """
    Quantize a PyTorch model using ModelOpt quantization.

    This function performs post-training quantization (PTQ) on the model using
    calibration data from the provided tokenizer. It supports both FP8 and NVFP4
    quantization formats.

    Args:
        model: PyTorch model to quantize
        args: Arguments containing quantization format and debug settings
        tokenizer: Tokenizer for creating calibration dataloader

    Returns:
        Quantized model with reduced precision weights and activations

    Raises:
        RuntimeError: If unsupported quantization format is specified
    """
    # Create calibration dataloader for quantization
    calib_dataloader = get_dataset_dataloader(
        tokenizer=tokenizer,
        batch_size=32,
        num_samples=512,
        device="cuda:0",
    )
    if args.qformat == "fp8":
        quant_cfg = mtq.FP8_DEFAULT_CFG
    elif args.qformat == "nvfp4":
        quant_cfg = mtq.NVFP4_DEFAULT_CFG
    else:
        raise RuntimeError("Unsupported quantization format")
    calibrate_loop = create_forward_loop(dataloader=calib_dataloader)

    model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    if args.debug:
        mtq.print_quant_summary(model)

    return model

class TensorRTQuantizedLinear(torch.nn.Module):
    """
    TensorRT quantized linear layer that applies quantization to both input and weight tensors.

    This class implements a quantized linear layer that:
    1. Applies quantization to input tensor using TensorQuantizer
    2. Applies quantization to weight tensor using TensorQuantizer
    3. Performs linear operation with quantized tensors
    """

    def __init__(
        self, original_linear: torch.nn.Linear, input_amax, weight_amax, quant_cfg
    ):
        """
        Initialize quantized linear layer.

        Args:
            original_linear: Original PyTorch linear layer to quantize
            input_amax: Maximum absolute value for input quantization scaling
            weight_amax: Maximum absolute value for weight quantization scaling
            quant_cfg: Quantization configuration for TensorQuantizer
        """
        super().__init__()

        # Store reference to original linear layer for weight access
        self.original_linear = original_linear

        # Copy bias from original layer if it exists
        if original_linear.bias is not None:
            self.bias = torch.nn.Parameter(original_linear.bias.clone()).cuda()
        else:
            self.bias = None

        # Create quantizers for input and weight tensors
        self.input_quantizer = TensorQuantizer(
            quant_attribute_cfg=quant_cfg, amax=input_amax
        )
        self.weight_quantizer = TensorQuantizer(
            quant_attribute_cfg=quant_cfg, amax=weight_amax
        )

    def forward(self, input):
        input = self.input_quantizer(input)
        weight = self.weight_quantizer(self.original_linear.weight)
        return torch.nn.functional.linear(input, weight, self.bias)

def convert_linear_to_tensorrt_quantized(model, model_name):
    """
    Convert linear layers in a model to TensorRT quantized versions using pre-quantized weights.

    This function is specifically designed for Hugging Face quantized models and only
    applies quantization to linear operations. It loads pre-quantized models from
    Hugging Face format and replaces standard linear layers with TensorRTQuantizedLinear
    layers. It supports both FP8 and NVFP4 quantization formats.

    The function:
    1. Loads quantization scales from Hugging Face model files (SafeTensors)
    2. Parses quantization configuration from hf_quant_config.json
    3. Replaces standard linear layers with TensorRTQuantizedLinear layers
    4. Applies appropriate quantization based on the model's quantization format

    Note: This function only quantizes linear operations and is intended for use
    with pre-quantized Hugging Face models that have been quantized using ModelOpt.

    Args:
        model: PyTorch model to quantize
        model_name: Path to Hugging Face model directory or model identifier

    Returns:
        Model with quantized linear layers

    Raises:
        RuntimeError: If quantization config is not found or unsupported format
    """
    # Determine if model_name is a local directory or needs to be downloaded
    if os.path.isdir(model_name):
        hf_folder = model_name
    else:
        # Download model from Hugging Face Hub
        hf_folder = snapshot_download(
            model_name,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_patterns=["original/**/*"],
            revision=None,
        )

    # Load all tensors from SafeTensors files
    tensors = {}
    for file in os.listdir(hf_folder):
        if file.endswith(".safetensors"):
            with safe_open(
                os.path.join(hf_folder, file), framework="pt", device="cpu"
            ) as f:
                tensor_names = f.keys()
                for name in tensor_names:
                    tensors[name] = f.get_tensor(name)

    # Load and parse quantization configuration
    hf_quant_config_path = f"{hf_folder}/hf_quant_config.json"
    if os.path.exists(hf_quant_config_path):
        with open(hf_quant_config_path, "r") as f:
            hf_quant_config = json.load(f)
            hf_quant_config = hf_quant_config["quantization"]

            hf_quant_algo = hf_quant_config.pop("quant_algo", None)
            if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
                raise RuntimeError("Only FP8 and NVFP4 quantization is supported")
    else:
        raise RuntimeError("No quantization config found")

    # Iterate through all modules in the model
    for name, module in model.named_modules():
        # Check if the module is a linear layer
        target = torch.nn.modules.linear.Linear
        if isinstance(module, target):
            # Construct names for quantization scale tensors
            # These follow the naming convention: module_name.weight_scale and module_name.input_scale
            weight_scale_name = name + ".weight_scale"
            input_scale_name = name + ".input_scale"

            # Verify that required scale tensors exist in the loaded data
            if weight_scale_name not in tensors:
                print(f"Weight scale tensor {weight_scale_name} not found")
                continue
            if input_scale_name not in tensors:
                print(f"Input scale tensor {input_scale_name} not found")
                continue

            if hf_quant_algo == "FP8":
                # FP8 E4M3 format has a maximum representable value of 448.0
                # Scale the quantization parameters accordingly
                weight_scale = tensors.pop(weight_scale_name)
                weight_amax = weight_scale * 448.0
                input_amax = tensors.pop(input_scale_name) * 448.0

                # Dequantize the weight using the scale factor
                dequantized_weight_data = module.weight.to(torch.float16) * weight_scale

                # Configure quantizer for FP8 format (4 exponent bits, 3 mantissa bits)
                quantizer_attribute_config = QuantizerAttributeConfig(
                    num_bits=(4, 3), axis=None
                )

            elif hf_quant_algo == "NVFP4":
                # NVFP4 format requires additional scale tensor and different configuration
                weight_name = name + ".weight"
                weight_scale2_name = name + ".weight_scale_2"
                weight_scale = tensors.pop(weight_scale_name)
                input_scale = tensors.pop(input_scale_name)
                weight_scale2 = tensors.pop(weight_scale2_name)

                # Calculate amax values with additional scaling factor for NVFP4
                input_amax = input_scale * 448.0 * 6.0
                weight_amax = weight_scale2 * 448.0 * 6.0

                # Handle NVFP4 tensor format
                weight_data = tensors.pop(weight_name)
                original_shape = list(weight_data.shape)
                original_shape[-1] *= 2  # NVFP4 packs 2 values per element
                nvfp4_tensor = NVFP4QTensor(
                    torch.Size(original_shape), torch.float16, weight_data
                )

                # Dequantize using both scales and block size configuration
                dequantized_weight_data = nvfp4_tensor.dequantize(
                    scale=weight_scale, double_scale=weight_scale2, block_sizes={-1: 16}
                )

                # Configure quantizer for NVFP4 format with dynamic block quantization
                quantizer_attribute_config = QuantizerAttributeConfig(
                    num_bits=(2, 1),
                    axis=None,
                    block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
                    enable=True,
                )

            # Apply dequantization to the original quantized weight using the scale
            # This ensures the weight is in the correct range for the quantized layer
            module.weight.data = dequantized_weight_data

            # Create the quantized linear layer with calculated amax values
            quantized_module = TensorRTQuantizedLinear(
                module, input_amax, weight_amax, quantizer_attribute_config
            )

            # Replace the original module with the quantized version
            # Extract parent module name and child module name
            parent_name = ".".join(name.split(".")[:-1])
            child_name = name.split(".")[-1]

            if parent_name:
                # Get the parent module and replace the child
                parent_module = model.get_submodule(parent_name)
                setattr(parent_module, child_name, quantized_module)
            else:
                # If no parent, replace at model level
                setattr(model, child_name, quantized_module)

    # Log any unused tensors for debugging
    if len(tensors) > 0:
        logger.debug(f"{len(tensors)} tensors not used")
        for key in tensors:
            logger.debug(f"  {key}")
    return model
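For reference, here is a minimal sketch of how the two entry points in quantize_utils.py might be driven. It assumes a Hugging Face causal LM, its tokenizer, and an argparse-style namespace carrying the qformat, debug, and pre_quantized fields this module reads; the checkpoint name and device placement are illustrative assumptions, not part of this commit.

# Hypothetical driver for the two entry points above; the checkpoint name,
# the args namespace, and device placement are illustrative assumptions.
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer

from quantize_utils import convert_linear_to_tensorrt_quantized, quantize_model

args = argparse.Namespace(qformat="nvfp4", debug=True, pre_quantized=False)
model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval().cuda()

if args.pre_quantized:
    # Reuse the scales shipped with a ModelOpt pre-quantized HF checkpoint.
    model = convert_linear_to_tensorrt_quantized(model, model_id).cuda()
else:
    # Calibration-based PTQ (FP8 or NVFP4) through ModelOpt.
    model = quantize_model(model, args, tokenizer)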

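The scale-to-amax arithmetic in convert_linear_to_tensorrt_quantized relies on two format constants: 448.0 is the largest finite FP8 E4M3 value, and 6.0 is the largest finite FP4 E2M1 value used for NVFP4 elements. A toy sketch of that bookkeeping follows; the scale values are made up, not taken from any checkpoint.

# Toy numbers only; mirrors the amax reconstruction above, not checkpoint data.
import torch

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
E2M1_MAX = 6.0    # largest finite FP4 E2M1 value (NVFP4 element format)

# FP8 checkpoints store a per-tensor scale; amax is recovered as scale * 448.
fp8_weight_scale = torch.tensor(0.01)
fp8_weight_amax = fp8_weight_scale * E4M3_MAX  # 4.48

# NVFP4 checkpoints add a global weight_scale_2 on top of per-block FP8 scales;
# the tensor-level amax folds in both dynamic ranges: scale_2 * 448 * 6.
weight_scale_2 = torch.tensor(0.002)
nvfp4_weight_amax = weight_scale_2 * E4M3_MAX * E2M1_MAX  # 5.376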
tools/llm/run_llm.py

Lines changed: 14 additions & 4 deletions
@@ -22,11 +22,9 @@
 from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import (
-    convert_linear_to_tensorrt_quantized,
     export_llm,
     generate,
     generate_with_static_cache,
-    quantize_model,
     record_stats,
     time_generate,
 )
@@ -58,12 +56,13 @@ def get_model(args):
             args.model,
             use_cache=False,
             attn_implementation="sdpa",
+            ignore_mismatched_sizes=True,
         )
         .eval()
         .cuda()
     )
     if args.pre_quantized:
-        model = convert_linear_to_tensorrt_quantized(model, args.model)
+        model = convert_linear_to_tensorrt_quantized(model, args.model).cuda()

     if args.precision == "FP16":
         model = model.to(torch.float16)
@@ -97,7 +96,8 @@ def compile_torchtrt(model, input_ids, args):
     for optimized inference
     """
     max_seq_len = input_ids.shape[1] + args.num_tokens
-    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
+    with export_torch_mode() if args.qformat or args.pre_quantized else nullcontext():
+        ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
     position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
     # Set precision specific flags
     use_fp32_acc = False
@@ -269,6 +269,14 @@ def measure_perf(trt_model, input_signature, backend_name):
         help="Use pre-quantized model weights (default: False)",
     )
     args = arg_parser.parse_args()
+
+    if args.qformat or args.pre_quantized:
+        from modelopt.torch.quantization.utils import export_torch_mode
+        from quantize_utils import (
+            convert_linear_to_tensorrt_quantized,
+            quantize_model,
+        )
+
     with torch.inference_mode():
         model = get_model(args)

@@ -390,6 +398,8 @@ def measure_perf(trt_model, input_signature, backend_name):
         match_result = str(torch.equal(pyt_gen_tokens, trt_gen_tokens))
         out_json_file = f"{model_name}_{qformat}_match.json"
         result = {}
+        args_dict = vars(args)
+        result["args"] = args_dict
         result["match"] = match_result
         result["torch_out"] = torch_out
         result["trt_out"] = trt_out
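The compile_torchtrt change above gates torch.export behind ModelOpt's export_torch_mode() only when quantization is active, and falls back to nullcontext() otherwise. Below is a minimal sketch of that pattern; the export_for_trt wrapper is a made-up name for illustration, while the export_torch_mode import is the one this commit adds to run_llm.py, and export_llm and args are assumed to come from the surrounding script.

# Sketch of the conditional-export pattern; export_for_trt is a hypothetical
# wrapper, not a function from this commit.
from contextlib import nullcontext

from modelopt.torch.quantization.utils import export_torch_mode


def export_for_trt(model, input_ids, args, export_llm):
    max_seq_len = input_ids.shape[1] + args.num_tokens
    # export_torch_mode() keeps ModelOpt quantizers in a torch.export-friendly
    # state; non-quantized runs fall through to a no-op context.
    ctx = export_torch_mode() if (args.qformat or args.pre_quantized) else nullcontext()
    with ctx:
        return export_llm(model, input_ids, max_seq_len=max_seq_len)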
