Commit 303775e

chore: Detect pre-quantized hf model
1 parent 1f1cf7f commit 303775e

File tree

3 files changed: +71 additions, -69 deletions

tools/llm/README.md

Lines changed: 3 additions & 4 deletions

````diff
@@ -40,8 +40,7 @@ python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is par
 - `--tokenizer`: (Optional) Tokenizer name; defaults to model.
 - `--prompt`: Input prompt for generation.
 - `--precision`: Precision mode (`FP16`, `FP32`).
-- `--qformat`: Quantization format (`fp8`, `nvfp4`) to apply.
-- `--pre_quantized`: Flag to use pre-quantized models from HuggingFace.
+- `--quant_format`: Quantization format (`fp8`, `nvfp4`) to apply.
 - `--num_tokens`: Number of output tokens to generate.
 - `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
 - `--benchmark`: Enable benchmarking mode.
@@ -56,15 +55,15 @@ Torch-TensorRT supports quantization to reduce model memory footprint and improv
 To use pre-quantized models from HuggingFace:
 
 ```bash
-python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --pre_quantized --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
+python run_llm.py --model nvidia/Llama-3.1-8B-Instruct-FP8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
 ```
 
 #### Applying quantization by ModelOpt
 
 Apply fp8 quantization from HuggingFace:
 
 ```bash
-python run_llm.py --model meta-llama/Llama-3.1-8B --qformat fp8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
+python run_llm.py --model meta-llama/Llama-3.1-8B --quant_format fp8 --prompt "What is parallel programming?" --precision FP16 --num_tokens 128
 ```
 
 #### Quantization Requirements
````
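Note (not part of the diff): with `--pre_quantized` removed, the pre-quantized path is now taken automatically whenever the checkpoint ships an `hf_quant_config.json`. A minimal sketch of how one might check a Hub repo for that marker file up front; the helper name and the use of `huggingface_hub.list_repo_files` are illustrative assumptions, not code from this commit:

```python
# Hypothetical helper: report whether a Hub checkpoint would be picked up as
# pre-quantized, i.e. whether it ships the hf_quant_config.json marker file.
from huggingface_hub import list_repo_files


def is_pre_quantized(repo_id: str) -> bool:
    # load_quantization_config() in quantize_utils.py looks for this file
    return "hf_quant_config.json" in list_repo_files(repo_id)


if __name__ == "__main__":
    # Expected to print True for ModelOpt-exported checkpoints such as this one
    print(is_pre_quantized("nvidia/Llama-3.1-8B-Instruct-FP8"))
```
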

tools/llm/quantize_utils.py

Lines changed: 54 additions & 44 deletions
```diff
@@ -27,22 +27,22 @@
 
 def quantize_model(model, args, tokenizer):
     """
-    Quantize a PyTorch model using ModelOpt quantization.
+    Quantize a PyTorch model using ModelOpt post-training quantization (PTQ).
 
-    This function performs post-training quantization (PTQ) on the model using
-    calibration data from the provided tokenizer. It supports both FP8 and NVFP4
-    quantization formats.
+    This function applies quantization to reduce model precision for faster inference
+    while maintaining acceptable accuracy. It uses calibration data generated from
+    the provided tokenizer to determine optimal quantization parameters.
 
+    Supported quantization formats:
+    - fp8: 8-bit floating point quantization
+    - nvfp4: 4-bit NVIDIA floating point quantization
     Args:
-        model: PyTorch model to quantize
-        args: Arguments containing quantization format and debug settings
-        tokenizer: Tokenizer for creating calibration dataloader
+        model: PyTorch model to quantize. Must be in evaluation mode.
+        args: Command line arguments containing quant_format and debug
+        tokenizer: Hugging Face tokenizer for creating calibration data
 
     Returns:
-        Quantized model with reduced precision weights and activations
-
-    Raises:
-        RuntimeError: If unsupported quantization format is specified
+        Quantized model
     """
     # Create calibration dataloader for quantization
     calib_dataloader = get_dataset_dataloader(
@@ -51,9 +51,9 @@ def quantize_model(model, args, tokenizer):
         num_samples=512,
         device="cuda:0",
     )
-    if args.qformat == "fp8":
+    if args.quant_format == "fp8":
         quant_cfg = mtq.FP8_DEFAULT_CFG
-    elif args.qformat == "nvfp4":
+    elif args.quant_format == "nvfp4":
         quant_cfg = mtq.NVFP4_DEFAULT_CFG
     else:
         raise RuntimeError("Unsupported quantization format")
```

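The rest of `quantize_model` (unchanged, so not shown in these hunks) hands `quant_cfg` and the calibration dataloader to ModelOpt's PTQ entry point. A rough sketch of that pattern, with a toy model and random batches standing in for the real LLM and dataloader; only `mtq.quantize` and the `*_DEFAULT_CFG` configs are taken from ModelOpt, everything else here is an assumption for illustration:

```python
# PTQ sketch, assuming nvidia-modelopt is installed and a CUDA device is available.
import torch
import modelopt.torch.quantization as mtq

toy_model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().eval()
calib_batches = [torch.randn(8, 64, device="cuda") for _ in range(4)]


def forward_loop(model):
    # ModelOpt drives calibration by running representative data through the model
    for batch in calib_batches:
        model(batch)


# mtq.NVFP4_DEFAULT_CFG would be used for --quant_format nvfp4
quantized = mtq.quantize(toy_model, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)
```
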
```diff
@@ -108,7 +108,38 @@ def forward(self, input):
         return torch.nn.functional.linear(input, weight, self.bias)
 
 
-def convert_linear_to_tensorrt_quantized(model, model_name):
+def load_quantization_config(model_name):
+    """
+    Load quantization configuration from a Hugging Face model.
+    Args:
+        model_name (str): Local directory path or model identifier
+    Returns:
+        dict or None: Quantization configuration. None if no config found.
+    """
+    # Determine if model_name is a local directory or needs to be downloaded
+    if os.path.isdir(model_name):
+        model_path = model_name
+    else:
+        # Download model from Hugging Face Hub
+        model_path = snapshot_download(
+            model_name,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_patterns=["original/**/*"],
+            revision=None,
+        )
+    hf_quant_config = None
+    # Load and parse quantization configuration
+    hf_quant_config_path = f"{model_path}/hf_quant_config.json"
+    if os.path.exists(hf_quant_config_path):
+        with open(hf_quant_config_path, "r") as f:
+            hf_quant_config = json.load(f)
+            hf_quant_config = hf_quant_config["quantization"]
+            hf_quant_config["model_path"] = model_path
+
+    return hf_quant_config
+
+
+def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
     """
     Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.
 
```

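For orientation, a hypothetical example of what the new helper returns for an FP8 checkpoint; only the `quant_algo` value and the injected `model_path` key are consumed by this commit, and real configs may carry additional fields:

```python
# Illustrative return value of load_quantization_config(); the path is made up.
hf_quant_config = {
    "quant_algo": "FP8",  # "NVFP4" for 4-bit checkpoints
    "model_path": "/path/to/local/snapshot",
}

# A truthy dict means the model is treated as pre-quantized in get_model()
if hf_quant_config:
    print(f"Pre-quantized checkpoint detected: {hf_quant_config['quant_algo']}")
```
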
```diff
@@ -119,58 +150,37 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
 
     The function:
     1. Loads quantization scales from Hugging Face model files (SafeTensors)
-    2. Parses quantization configuration from hf_quant_config.json
-    3. Replaces standard linear layers with TensorRTQuantizedLinear layers
-    4. Applies appropriate quantization based on the model's quantization format
+    2. Replaces standard linear layers with TensorRTQuantizedLinear layers
+    3. Applies appropriate quantization based on the model's quantization format
 
     Note: This function only quantizes linear operations and is intended for use
     with pre-quantized Hugging Face models that have been quantized using ModelOpt.
 
     Args:
         model: PyTorch model to quantize
-        model_name: Path to Hugging Face model directory or model identifier
+        hf_quant_config: Quantization configuration
 
     Returns:
         Model with quantized linear layers
 
     Raises:
         RuntimeError: If quantization config is not found or unsupported format
     """
-    # Determine if model_name is a local directory or needs to be downloaded
-    if os.path.isdir(model_name):
-        hf_folder = model_name
-    else:
-        # Download model from Hugging Face Hub
-        hf_folder = snapshot_download(
-            model_name,
-            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-            ignore_patterns=["original/**/*"],
-            revision=None,
-        )
-
+    model_path = hf_quant_config["model_path"]
     # Load all tensors from SafeTensors files
     tensors = {}
-    for file in os.listdir(hf_folder):
+    for file in os.listdir(model_path):
         if file.endswith(".safetensors"):
             with safe_open(
-                os.path.join(hf_folder, file), framework="pt", device="cpu"
+                os.path.join(model_path, file), framework="pt", device="cpu"
             ) as f:
                 tensor_names = f.keys()
                 for name in tensor_names:
                     tensors[name] = f.get_tensor(name)
 
-    # Load and parse quantization configuration
-    hf_quant_config_path = f"{hf_folder}/hf_quant_config.json"
-    if os.path.exists(hf_quant_config_path):
-        with open(hf_quant_config_path, "r") as f:
-            hf_quant_config = json.load(f)
-            hf_quant_config = hf_quant_config["quantization"]
-
-        hf_quant_algo = hf_quant_config.pop("quant_algo", None)
-        if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
-            raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
-    else:
-        raise RuntimeError("No quantization config found")
+    hf_quant_algo = hf_quant_config.get("quant_algo", None)
+    if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
+        raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
 
     # Iterate through all modules in the model
     for name, module in model.named_modules():
```

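The hunk ends where the module walk begins; the actual layer swap lives in the unchanged remainder of the function. As a generic illustration of that replace-the-linear-layers pattern only (a sketch, not the real `TensorRTQuantizedLinear` logic, which also wires in the scales loaded above):

```python
# Generic module-replacement sketch; QuantLinear is a stand-in wrapper,
# not the TensorRTQuantizedLinear class from quantize_utils.py.
import torch.nn as nn


class QuantLinear(nn.Module):
    def __init__(self, linear: nn.Linear):
        super().__init__()
        # The real code would also attach weight/input scales from the safetensors files
        self.inner = linear

    def forward(self, x):
        return self.inner(x)


def replace_linears(module: nn.Module) -> nn.Module:
    # Recursively swap every nn.Linear child for the wrapper
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, QuantLinear(child))
        else:
            replace_linears(child)
    return module
```
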

tools/llm/run_llm.py

Lines changed: 14 additions & 21 deletions
```diff
@@ -19,6 +19,12 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 import torch
 import torch_tensorrt
+from modelopt.torch.quantization.utils import export_torch_mode
+from quantize_utils import (
+    convert_linear_to_tensorrt_quantized,
+    load_quantization_config,
+    quantize_model,
+)
 from torchtrt_ext import register_sdpa
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from utils import (
```

```diff
@@ -60,8 +66,11 @@ def get_model(args):
         .eval()
         .cuda()
     )
-    if args.pre_quantized:
-        model = convert_linear_to_tensorrt_quantized(model, args.model).cuda()
+
+    hf_quant_config = load_quantization_config(args.model)
+    if hf_quant_config:
+        model = convert_linear_to_tensorrt_quantized(model, hf_quant_config).cuda()
+        print(f"Model converted to TensorRT quantized")
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
```

```diff
@@ -95,7 +104,7 @@ def compile_torchtrt(model, input_ids, args):
         for optimized inference
     """
     max_seq_len = input_ids.shape[1] + args.num_tokens
-    with export_torch_mode() if args.qformat or args.pre_quantized else nullcontext():
+    with export_torch_mode():
         ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
         position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
         # Set precision specific flags
```

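With the imports now unconditional, `export_torch_mode()` wraps the export path for quantized and non-quantized models alike; the removal of the `nullcontext()` branch implies it is harmless when no quantizers are present. A tiny sketch of the pattern in isolation (the toy module and example input are assumptions):

```python
# Sketch: torch.export under ModelOpt's export-mode context manager.
import torch
from modelopt.torch.quantization.utils import export_torch_mode

toy = torch.nn.Linear(16, 16).eval()
example = torch.randn(2, 16)

with export_torch_mode():
    exported = torch.export.export(toy, (example,))
```
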
```diff
@@ -240,28 +249,12 @@ def measure_perf(trt_model, input_signature, backend_name):
         "--benchmark", action="store_true", help="Enable benchmark (default: False)"
     )
     arg_parser.add_argument(
-        "--qformat",
+        "--quant_format",
         help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"),
         default=None,
     )
-    arg_parser.add_argument(
-        "--pre_quantized",
-        action="store_true",
-        help="Use pre-quantized hf model weights (default: False)",
-    )
     args = arg_parser.parse_args()
 
-    if args.qformat and args.pre_quantized:
-        print("Error: --qformat and --pre_quantized cannot be used together")
-        exit()
-
-    if args.qformat or args.pre_quantized:
-        from modelopt.torch.quantization.utils import export_torch_mode
-        from quantize_utils import (
-            convert_linear_to_tensorrt_quantized,
-            quantize_model,
-        )
-
     with torch.inference_mode():
         model = get_model(args)
 
```

```diff
@@ -286,7 +279,7 @@ def measure_perf(trt_model, input_signature, backend_name):
         pyt_timings = None
         pyt_stats = None
 
-        if args.qformat != None:
+        if args.quant_format != None:
             model = quantize_model(model, args, tokenizer)
             if args.enable_pytorch_run:
                 pyt_gen_tokens = generate(
```
