def quantize_model(model, args, tokenizer):
    """
-    Quantize a PyTorch model using ModelOpt quantization.
+    Quantize a PyTorch model using ModelOpt post-training quantization (PTQ).

-    This function performs post-training quantization (PTQ) on the model using
-    calibration data from the provided tokenizer. It supports both FP8 and NVFP4
-    quantization formats.
+    This function applies quantization to reduce model precision for faster inference
+    while maintaining acceptable accuracy. It uses calibration data generated from
+    the provided tokenizer to determine optimal quantization parameters.

+    Supported quantization formats:
+    - fp8: 8-bit floating point quantization
+    - nvfp4: 4-bit NVIDIA floating point quantization

    Args:
-        model: PyTorch model to quantize
-        args: Arguments containing quantization format and debug settings
-        tokenizer: Tokenizer for creating calibration dataloader
+        model: PyTorch model to quantize. Must be in evaluation mode.
+        args: Command line arguments containing quant_format and debug
+        tokenizer: Hugging Face tokenizer for creating calibration data

    Returns:
-        Quantized model with reduced precision weights and activations
-
-    Raises:
-        RuntimeError: If unsupported quantization format is specified
+        Quantized model

    """
    # Create calibration dataloader for quantization
    calib_dataloader = get_dataset_dataloader(
@@ -51,9 +51,9 @@ def quantize_model(model, args, tokenizer):
        num_samples=512,
        device="cuda:0",
    )
-    if args.qformat == "fp8":
+    if args.quant_format == "fp8":
        quant_cfg = mtq.FP8_DEFAULT_CFG
-    elif args.qformat == "nvfp4":
+    elif args.quant_format == "nvfp4":
        quant_cfg = mtq.NVFP4_DEFAULT_CFG
    else:
        raise RuntimeError("Unsupported quantization format")
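For context on the renamed flag: the `quant_cfg` selected from `args.quant_format` is what ultimately drives the PTQ step. Below is a minimal sketch of how these pieces typically fit together with ModelOpt's `mtq.quantize(model, cfg, forward_loop)` entry point; `run_calibration` is a hypothetical helper and the batch layout is assumed, neither is copied from this patch.

import modelopt.torch.quantization as mtq

def run_calibration(model):
    # Hypothetical calibration loop: push calibration batches through the model
    # so ModelOpt can observe activation ranges and compute quantization scales.
    # Assumes the dataloader yields dicts with an "input_ids" tensor.
    for batch in calib_dataloader:
        model(batch["input_ids"].to("cuda:0"))

# quant_cfg is mtq.FP8_DEFAULT_CFG or mtq.NVFP4_DEFAULT_CFG, chosen from args.quant_format.
model = mtq.quantize(model, quant_cfg, forward_loop=run_calibration)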
@@ -108,7 +108,38 @@ def forward(self, input):
        return torch.nn.functional.linear(input, weight, self.bias)


-def convert_linear_to_tensorrt_quantized(model, model_name):
+def load_quantization_config(model_name):
+    """
+    Load quantization configuration from a Hugging Face model.
+    Args:
+        model_name (str): Local directory path or model identifier
+    Returns:
+        dict or None: Quantization configuration. None if no config found.
+    """
+    # Determine if model_name is a local directory or needs to be downloaded
+    if os.path.isdir(model_name):
+        model_path = model_name
+    else:
+        # Download model from Hugging Face Hub
+        model_path = snapshot_download(
+            model_name,
+            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+            ignore_patterns=["original/**/*"],
+            revision=None,
+        )
+    hf_quant_config = None
+    # Load and parse quantization configuration
+    hf_quant_config_path = f"{model_path}/hf_quant_config.json"
+    if os.path.exists(hf_quant_config_path):
+        with open(hf_quant_config_path, "r") as f:
+            hf_quant_config = json.load(f)
+            hf_quant_config = hf_quant_config["quantization"]
+            hf_quant_config["model_path"] = model_path
+
+    return hf_quant_config
+
+
+def convert_linear_to_tensorrt_quantized(model, hf_quant_config):
    """
    Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.
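With the configuration loading factored out into its own function, the expected call sequence is to load the config once and hand the resulting dict to the conversion step. A short usage sketch follows; the model identifier is a placeholder, and only the `quant_algo` and `model_path` keys visible in this patch are assumed to be present.

# Load hf_quant_config.json from a local checkout or a Hub snapshot.
hf_quant_config = load_quantization_config("org/pre-quantized-model")  # placeholder model id
if hf_quant_config is None:
    raise RuntimeError("No quantization config found")

# The dict carries at least {"quant_algo": "FP8" or "NVFP4", "model_path": <snapshot dir>}.
model = convert_linear_to_tensorrt_quantized(model, hf_quant_config)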
@@ -119,58 +150,37 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
    The function:
    1. Loads quantization scales from Hugging Face model files (SafeTensors)
-    2. Parses quantization configuration from hf_quant_config.json
-    3. Replaces standard linear layers with TensorRTQuantizedLinear layers
-    4. Applies appropriate quantization based on the model's quantization format
+    2. Replaces standard linear layers with TensorRTQuantizedLinear layers
+    3. Applies appropriate quantization based on the model's quantization format

    Note: This function only quantizes linear operations and is intended for use
    with pre-quantized Hugging Face models that have been quantized using ModelOpt.

    Args:
        model: PyTorch model to quantize
-        model_name: Path to Hugging Face model directory or model identifier
+        hf_quant_config: Quantization configuration

    Returns:
        Model with quantized linear layers

    Raises:
        RuntimeError: If quantization config is not found or unsupported format
    """
-    # Determine if model_name is a local directory or needs to be downloaded
-    if os.path.isdir(model_name):
-        hf_folder = model_name
-    else:
-        # Download model from Hugging Face Hub
-        hf_folder = snapshot_download(
-            model_name,
-            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-            ignore_patterns=["original/**/*"],
-            revision=None,
-        )
-
+    model_path = hf_quant_config["model_path"]

    # Load all tensors from SafeTensors files
    tensors = {}
-    for file in os.listdir(hf_folder):
+    for file in os.listdir(model_path):
        if file.endswith(".safetensors"):
            with safe_open(
-                os.path.join(hf_folder, file), framework="pt", device="cpu"
+                os.path.join(model_path, file), framework="pt", device="cpu"
            ) as f:
                tensor_names = f.keys()
                for name in tensor_names:
                    tensors[name] = f.get_tensor(name)

-    # Load and parse quantization configuration
-    hf_quant_config_path = f"{hf_folder}/hf_quant_config.json"
-    if os.path.exists(hf_quant_config_path):
-        with open(hf_quant_config_path, "r") as f:
-            hf_quant_config = json.load(f)
-            hf_quant_config = hf_quant_config["quantization"]
-
-        hf_quant_algo = hf_quant_config.pop("quant_algo", None)
-        if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
-            raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
-    else:
-        raise RuntimeError("No quantization config found")
+    hf_quant_algo = hf_quant_config.get("quant_algo", None)
+    if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
+        raise RuntimeError("Only FP8 or NVFP4 quantization is supported")

    # Iterate through all modules in the model
    for name, module in model.named_modules():
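The loop above walks model.named_modules() and swaps each plain nn.Linear for the TensorRTQuantizedLinear wrapper described in the docstring. A generic sketch of that replacement pattern is shown below; the scale key naming and the wrapper's constructor arguments are assumptions for illustration, not taken from this diff.

import torch

for name, module in model.named_modules():
    for child_name, child in list(module.named_children()):
        if isinstance(child, torch.nn.Linear):
            # Scale tensors were exported by ModelOpt into the SafeTensors files
            # loaded above into `tensors`; the exact key format here is assumed.
            full_name = f"{name}.{child_name}" if name else child_name
            weight_scale = tensors.get(f"{full_name}.weight_scale")
            # Constructor signature of TensorRTQuantizedLinear assumed for illustration.
            setattr(module, child_name, TensorRTQuantizedLinear(child, weight_scale))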